1717from contextlib import asynccontextmanager
1818import copy
1919from datetime import datetime
20- from datetime import timezone
2120import logging
2221from typing import Any
2322from typing import AsyncIterator
5958from .schemas .v1 import StorageMetadata
6059from .schemas .v1 import StorageSession as StorageSessionV1
6160from .schemas .v1 import StorageUserState as StorageUserStateV1
61+ from .schemas .shared import update_time_from_timestamp
6262from .session import Session
6363from .state import State
6464
@@ -458,11 +458,10 @@ async def create_session(
458458 storage_user_state .state = storage_user_state .state | user_state_delta
459459
460460 # Store the session
461- now = datetime .fromtimestamp (platform_time .get_time (), tz = timezone .utc )
462- is_sqlite = self .db_engine .dialect .name == _SQLITE_DIALECT
463- is_postgresql = self .db_engine .dialect .name == _POSTGRESQL_DIALECT
464- if is_sqlite or is_postgresql :
465- now = now .replace (tzinfo = None )
461+ dialect_name = self .db_engine .dialect .name
462+ now = update_time_from_timestamp (
463+ platform_time .get_time (), dialect_name
464+ )
466465
467466 storage_session = schema .StorageSession (
468467 app_name = app_name ,
@@ -480,7 +479,7 @@ async def create_session(
480479 storage_app_state .state , storage_user_state .state , session_state
481480 )
482481 session = storage_session .to_session (
483- state = merged_state , is_sqlite = is_sqlite , is_postgresql = is_postgresql
482+ state = merged_state , dialect_name = dialect_name
484483 )
485484 return session
486485
@@ -498,8 +497,7 @@ async def get_session(
498497 # 2. Get all the events based on session id and filtering config
499498 # 3. Convert and return the session
500499 schema = self ._get_schema_classes ()
501- is_sqlite = self .db_engine .dialect .name == _SQLITE_DIALECT
502- is_postgresql = self .db_engine .dialect .name == _POSTGRESQL_DIALECT
500+ dialect_name = self .db_engine .dialect .name
503501 async with self ._rollback_on_exception_session (
504502 read_only = True
505503 ) as sql_session :
@@ -548,8 +546,7 @@ async def get_session(
548546 session = storage_session .to_session (
549547 state = merged_state ,
550548 events = events ,
551- is_sqlite = is_sqlite ,
552- is_postgresql = is_postgresql ,
549+ dialect_name = dialect_name ,
553550 )
554551 return session
555552
@@ -595,17 +592,15 @@ async def list_sessions(
595592 user_states_map [storage_user_state .user_id ] = storage_user_state .state
596593
597594 sessions = []
598- is_sqlite = self .db_engine .dialect .name == _SQLITE_DIALECT
599- is_postgresql = self .db_engine .dialect .name == _POSTGRESQL_DIALECT
595+ dialect_name = self .db_engine .dialect .name
600596 for storage_session in results :
601597 session_state = storage_session .state
602598 user_state = user_states_map .get (storage_session .user_id , {})
603599 merged_state = _merge_state (app_state , user_state , session_state )
604600 sessions .append (
605601 storage_session .to_session (
606602 state = merged_state ,
607- is_sqlite = is_sqlite ,
608- is_postgresql = is_postgresql ,
603+ dialect_name = dialect_name ,
609604 )
610605 )
611606 return ListSessionsResponse (sessions = sessions )
@@ -641,8 +636,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
641636 # 2. Update session attributes based on event config.
642637 # 3. Store the new event.
643638 schema = self ._get_schema_classes ()
644- is_sqlite = self .db_engine .dialect .name == _SQLITE_DIALECT
645- is_postgresql = self .db_engine .dialect .name == _POSTGRESQL_DIALECT
639+ dialect_name = self .db_engine .dialect .name
646640 use_row_level_locking = self ._supports_row_level_locking ()
647641
648642 state_delta = (
@@ -672,9 +666,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
672666 storage_session = storage_session_result .scalars ().one_or_none ()
673667 if storage_session is None :
674668 raise ValueError (f"Session { session .id } not found." )
675- storage_update_time = storage_session .get_update_timestamp (
676- is_sqlite , is_postgresql
677- )
669+ storage_update_time = storage_session .get_update_timestamp (dialect_name )
678670 storage_update_marker = storage_session .get_update_marker ()
679671
680672 storage_app_state = await _select_required_state (
@@ -740,20 +732,16 @@ async def append_event(self, session: Session, event: Event) -> Event:
740732 storage_session .state | state_deltas ["session" ]
741733 )
742734
743- if is_sqlite or is_postgresql :
744- update_time = datetime .fromtimestamp (
745- event .timestamp , timezone .utc
746- ).replace (tzinfo = None )
747- else :
748- update_time = datetime .fromtimestamp (event .timestamp , timezone .utc )
749- storage_session .update_time = update_time
735+ storage_session .update_time = update_time_from_timestamp (
736+ event .timestamp , dialect_name
737+ )
750738 sql_session .add (schema .StorageEvent .from_event (session , event ))
751739
752740 await sql_session .commit ()
753741
754742 # Update timestamp with commit time
755743 session .last_update_time = storage_session .get_update_timestamp (
756- is_sqlite , is_postgresql
744+ dialect_name
757745 )
758746 session ._storage_update_marker = storage_session .get_update_marker ()
759747
0 commit comments