Skip to content

Commit 72ed796

Browse files
authored
fix: Return UTC-aware datetimes from unmarshalling (#809)
Previously, datetime fields were unmarshalled using datetime.fromtimestamp(value) which returns a naive datetime in the server's local timezone. This caused: - Non-deterministic behavior depending on server timezone - Inability to compare retrieved datetimes with timezone-aware datetimes - Time jumps around daylight savings transitions This fix changes unmarshalling to use datetime.fromtimestamp(value, timezone.utc) which returns a UTC-aware datetime. This follows the standard ORM pattern of storing UTC and returning UTC-aware datetimes. BREAKING CHANGE: Retrieved datetime fields are now UTC-aware instead of naive local time. Code that compared retrieved datetimes with naive datetimes will need to either: 1. Make the comparison datetime UTC-aware, or 2. Use .timestamp() for comparison Fixes #807 (Return UTC-aware datetimes from unmarshalling) * style: Format test file with ruff * style: Format model.py with ruff * fix: Make test use UTC-aware datetime after fix * fix: preserve date round-trips across timezones * chore: keep sync output formatter-clean in CI * chore: align lint with current ruff rules * fix: normalize date query timestamps to UTC
1 parent dee6a3f commit 72ed796

27 files changed

Lines changed: 292 additions & 74 deletions

aredis_om/connections.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from . import redis
44

5+
56
URL = os.environ.get("REDIS_OM_URL", None)
67

78

aredis_om/model/encoders.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from pydantic import BaseModel
3535

36+
3637
try:
3738
from pydantic.deprecated.json import ENCODERS_BY_TYPE
3839
from pydantic_core import PydanticUndefined

aredis_om/model/migrations/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
SchemaMigrator,
1919
)
2020

21+
2122
__all__ = [
2223
# Data migrations
2324
"BaseMigration",

aredis_om/model/migrations/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
from .base import BaseMigration, DataMigrationError
99
from .migrator import DataMigrator
1010

11+
1112
__all__ = ["BaseMigration", "DataMigrationError", "DataMigrator"]

aredis_om/model/migrations/data/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import time
1010
from typing import Any, Dict, List
1111

12+
1213
try:
1314
import psutil
1415
except ImportError:

aredis_om/model/migrations/data/builtin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@
1111
DatetimeFieldMigration,
1212
)
1313

14+
1415
__all__ = ["DatetimeFieldMigration", "DatetimeFieldDetector", "ConversionFailureMode"]

aredis_om/model/migrations/data/builtin/datetime_migration.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from ..base import BaseMigration, DataMigrationError
1818

19+
1920
log = logging.getLogger(__name__)
2021

2122

@@ -180,9 +181,9 @@ def __init__(self):
180181
self.converted_fields = 0
181182
self.skipped_fields = 0
182183
self.failed_conversions = 0
183-
self.errors: List[Tuple[str, str, str, Exception]] = (
184-
[]
185-
) # (key, field, value, error)
184+
self.errors: List[
185+
Tuple[str, str, str, Exception]
186+
] = [] # (key, field, value, error)
186187

187188
def add_conversion_error(self, key: str, field: str, value: Any, error: Exception):
188189
"""Record a conversion error."""
@@ -393,7 +394,9 @@ async def save_progress(
393394
}
394395

395396
await self.redis.set(
396-
self.state_key, json.dumps(state_data), ex=86400 # Expire after 24 hours
397+
self.state_key,
398+
json.dumps(state_data),
399+
ex=86400, # Expire after 24 hours
397400
)
398401

399402
async def load_progress(self) -> Dict[str, Any]:

aredis_om/model/migrations/schema/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .legacy_migrator import MigrationAction, MigrationError, Migrator, SchemaDetector
1313
from .migrator import SchemaMigrator
1414

15+
1516
__all__ = [
1617
# Primary API
1718
"BaseSchemaMigration",

aredis_om/model/migrations/schema/legacy_migrator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import redis
2020

21+
2122
log = logging.getLogger(__name__)
2223

2324

@@ -52,7 +53,8 @@ def import_submodules(root_module_name: str):
5253
)
5354

5455
for loader, module_name, is_pkg in pkgutil.walk_packages(
55-
root_module.__path__, root_module.__name__ + "." # type: ignore
56+
root_module.__path__,
57+
root_module.__name__ + ".", # type: ignore
5658
):
5759
importlib.import_module(module_name)
5860

aredis_om/model/model.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,14 @@
2222
Type,
2323
TypeVar,
2424
Union,
25-
)
26-
from typing import get_args as typing_get_args
27-
from typing import (
2825
no_type_check,
2926
)
27+
from typing import get_args as typing_get_args
3028

3129
from more_itertools import ichunked
3230
from pydantic import BaseModel
3331

32+
3433
try:
3534
from pydantic import ConfigDict, TypeAdapter, field_validator
3635

@@ -73,6 +72,7 @@
7372
from .token_escaper import TokenEscaper
7473
from .types import Coordinates, CoordinateType, GeoFilter
7574

75+
7676
model_registry = {}
7777
_T = TypeVar("_T")
7878
Model = TypeVar("Model", bound="RedisModel")
@@ -115,8 +115,11 @@ def convert_datetime_to_timestamp(obj):
115115
elif isinstance(obj, datetime.datetime):
116116
return obj.timestamp()
117117
elif isinstance(obj, datetime.date):
118-
# Convert date to datetime at midnight and get timestamp
119-
dt = datetime.datetime.combine(obj, datetime.time.min)
118+
# Date values represent calendar days, so normalize to UTC midnight
119+
# to avoid timezone-dependent day shifts on round-trip conversion.
120+
dt = datetime.datetime.combine(
121+
obj, datetime.time.min, tzinfo=datetime.timezone.utc
122+
)
120123
return dt.timestamp()
121124
else:
122125
return obj
@@ -138,7 +141,9 @@ def convert_timestamp_to_datetime(obj, model_fields):
138141
# For Optional[T] which is Union[T, None], get the non-None type
139142
args = getattr(field_type, "__args__", ())
140143
non_none_types = [
141-
arg for arg in args if arg is not type(None) # noqa: E721
144+
arg
145+
for arg in args
146+
if arg is not type(None) # noqa: E721
142147
]
143148
if len(non_none_types) == 1:
144149
field_type = non_none_types[0]
@@ -150,8 +155,13 @@ def convert_timestamp_to_datetime(obj, model_fields):
150155
try:
151156
if isinstance(value, str):
152157
value = float(value)
153-
# Use fromtimestamp to preserve local timezone behavior
154-
dt = datetime.datetime.fromtimestamp(value)
158+
# Return UTC-aware datetime for consistency.
159+
# Timestamps are always UTC-referenced, so we return
160+
# UTC-aware datetimes. Users can convert to their
161+
# preferred timezone with dt.astimezone(tz).
162+
dt = datetime.datetime.fromtimestamp(
163+
value, datetime.timezone.utc
164+
)
155165
# If the field is specifically a date, convert to date
156166
if field_type is datetime.date:
157167
result[key] = dt.date()
@@ -255,7 +265,9 @@ def convert_base64_to_bytes(obj, model_fields):
255265
# For Optional[T] which is Union[T, None], get the non-None type
256266
args = getattr(field_type, "__args__", ())
257267
non_none_types = [
258-
arg for arg in args if arg is not type(None) # noqa: E721
268+
arg
269+
for arg in args
270+
if arg is not type(None) # noqa: E721
259271
]
260272
if len(non_none_types) == 1:
261273
field_type = non_none_types[0]
@@ -636,10 +648,10 @@ def embedded(cls):
636648

637649
def is_supported_container_type(typ: Optional[type]) -> bool:
638650
# TODO: Wait, why don't we support indexing sets?
639-
if typ == list or typ == tuple or typ == Literal:
651+
if typ is list or typ is tuple or typ is Literal:
640652
return True
641653
unwrapped = get_origin(typ)
642-
return unwrapped == list or unwrapped == tuple or unwrapped == Literal
654+
return unwrapped is list or unwrapped is tuple or unwrapped is Literal
643655

644656

645657
def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any]):
@@ -1056,7 +1068,7 @@ def _validate_deep_field_path(self, field_path: str):
10561068
field_type, RedisModel
10571069
):
10581070
current_model = field_type
1059-
elif field_type == dict:
1071+
elif field_type is dict:
10601072
# Dict fields - we can't validate nested paths, just accept them
10611073
return
10621074
else:
@@ -1089,7 +1101,7 @@ def _validate_deep_field_path(self, field_path: str):
10891101
field_type, RedisModel
10901102
):
10911103
current_model = field_type
1092-
elif field_type == dict:
1104+
elif field_type is dict:
10931105
return # Can't validate further into dict
10941106
else:
10951107
raise QueryNotSupportedError(
@@ -1174,18 +1186,18 @@ def _convert_projected_fields(self, raw_data: Dict[str, str]) -> Dict[str, Any]:
11741186
field_type = getattr(field_info, "type_", str)
11751187

11761188
# Handle common type conversions directly for efficiency
1177-
if field_type == int:
1189+
if field_type is int:
11781190
converted_data[field_name] = int(raw_value)
1179-
elif field_type == float:
1191+
elif field_type is float:
11801192
converted_data[field_name] = float(raw_value)
1181-
elif field_type == bool:
1193+
elif field_type is bool:
11821194
# Redis may store bool as "True"/"False" or "1"/"0"
11831195
converted_data[field_name] = raw_value.lower() in (
11841196
"true",
11851197
"1",
11861198
"yes",
11871199
)
1188-
elif field_type == str:
1200+
elif field_type is str:
11891201
converted_data[field_name] = raw_value
11901202
else:
11911203
# For complex types, keep as string (could be enhanced later)
@@ -1231,7 +1243,7 @@ def _has_complex_projected_fields(self) -> bool:
12311243
field_type = getattr(field_info, "annotation", None)
12321244

12331245
# Check for dict fields
1234-
if field_type == dict:
1246+
if field_type is dict:
12351247
return True
12361248

12371249
# Check for embedded models (subclasses of RedisModel)
@@ -1524,8 +1536,7 @@ def expand_tag_value(value):
15241536
return "|".join([escaper.escape(str(v)) for v in value])
15251537
except TypeError:
15261538
log.debug(
1527-
"Escaping single non-iterable value used for an IN or "
1528-
"NOT_IN query: %s",
1539+
"Escaping single non-iterable value used for an IN or NOT_IN query: %s",
15291540
value,
15301541
)
15311542
return escaper.escape(str(value))
@@ -1571,8 +1582,10 @@ def convert_numeric_value(v):
15711582
if isinstance(v, datetime.date) and not isinstance(
15721583
v, datetime.datetime
15731584
):
1574-
# Convert date to datetime at midnight
1575-
v = datetime.datetime.combine(v, datetime.time.min)
1585+
# Use UTC midnight so query conversion matches storage conversion.
1586+
v = datetime.datetime.combine(
1587+
v, datetime.time.min, tzinfo=datetime.timezone.utc
1588+
)
15761589
v = v.timestamp()
15771590
return v
15781591

@@ -3352,9 +3365,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
33523365
field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR
33533366
)
33543367
if getattr(field_info, "full_text_search", False) is True:
3355-
schema = (
3356-
f"{name} TAG SEPARATOR {separator} " f"{name} AS {name}_fts TEXT"
3357-
)
3368+
schema = f"{name} TAG SEPARATOR {separator} {name} AS {name}_fts TEXT"
33583369
else:
33593370
schema = f"{name} TAG SEPARATOR {separator}"
33603371
elif issubclass(typ, RedisModel):

0 commit comments

Comments
 (0)