2424from .subject import Subject
2525from .user import User
2626
27- OnConflictOpts = Enum ('OnConflictOpts' , ['ERROR' , 'OVERWRITE' , 'SKIP' ])
27+
28+ class OnConflictOpts (str , Enum ):
29+ ERROR = "error"
30+ OVERWRITE = "overwrite"
31+ SKIP = "skip"
32+
33+
34+ def _normalize_on_conflict (on_conflict : OnConflictOpts | str ) -> OnConflictOpts :
35+ if isinstance (on_conflict , OnConflictOpts ):
36+ return on_conflict
37+ if isinstance (on_conflict , str ):
38+ try :
39+ return OnConflictOpts (on_conflict .lower ())
40+ except ValueError as exc :
41+ raise ValueError ("Invalid 'on_conflict' option. Use 'error', 'overwrite', or 'skip'." ) from exc
42+ raise ValueError ("Invalid 'on_conflict' option. Use 'error', 'overwrite', or 'skip'." )
2843
2944
3045class Database :
@@ -59,7 +74,8 @@ def write_user(self, user: User, on_conflict: OnConflictOpts = OnConflictOpts.ER
5974
6075 self .logger .info (f"Added User with ID { user_id } to the database." )
6176
62- def delete_user (self , user_id : str , on_conflict : OnConflictOpts = OnConflictOpts .ERROR ) -> None :
77+ def delete_user (self , user_id : str , on_conflict : OnConflictOpts | str = OnConflictOpts .ERROR ) -> None :
78+ on_conflict = _normalize_on_conflict (on_conflict )
6379 # Check if the user ID already exists in the database
6480 user_ids = self .get_user_ids ()
6581
@@ -83,7 +99,8 @@ def delete_user(self, user_id: str, on_conflict: OnConflictOpts = OnConflictOpts
8399
84100 self .logger .info (f"Removed Sonication User with ID { user_id } from the database." )
85101
86- def write_protocol (self , protocol : Protocol , on_conflict : OnConflictOpts = OnConflictOpts .ERROR ):
102+ def write_protocol (self , protocol : Protocol , on_conflict : OnConflictOpts | str = OnConflictOpts .ERROR ):
103+ on_conflict = _normalize_on_conflict (on_conflict )
87104 # Check if the sonication protocol ID already exists in the database
88105 protocol_id = protocol .id
89106 protocol_ids = self .get_protocol_ids ()
@@ -110,7 +127,8 @@ def write_protocol(self, protocol: Protocol, on_conflict: OnConflictOpts = OnCon
110127
111128 self .logger .info (f"Added Sonication Protocol with ID { protocol_id } to the database." )
112129
113- def delete_protocol (self , protocol_id : str , on_conflict : OnConflictOpts = OnConflictOpts .ERROR ):
130+ def delete_protocol (self , protocol_id : str , on_conflict : OnConflictOpts | str = OnConflictOpts .ERROR ):
131+ on_conflict = _normalize_on_conflict (on_conflict )
114132 # Check if the sonication protocol ID already exists in the database
115133 protocol_ids = self .get_protocol_ids ()
116134
@@ -134,7 +152,8 @@ def delete_protocol(self, protocol_id: str, on_conflict: OnConflictOpts = OnConf
134152
135153 self .logger .info (f"Removed Sonication Protocol with ID { protocol_id } from the database." )
136154
137- def write_session (self , subject :Subject , session :Session , on_conflict = OnConflictOpts .ERROR ):
155+ def write_session (self , subject :Subject , session :Session , on_conflict : OnConflictOpts | str = OnConflictOpts .ERROR ):
156+ on_conflict = _normalize_on_conflict (on_conflict )
138157 # Generate session ID
139158 session_id = session .id
140159
@@ -199,14 +218,15 @@ def write_session(self, subject:Subject, session:Session, on_conflict=OnConflict
199218
200219 self .logger .info (f"Added session with ID { session_id } for subject { subject .id } to the database." )
201220
202- def delete_session (self , subject_id : str , session_id : str , on_conflict : OnConflictOpts = OnConflictOpts .ERROR ):
221+ def delete_session (self , subject_id : str , session_id : str , on_conflict : OnConflictOpts | str = OnConflictOpts .ERROR ):
203222 """Delete a session and its associated data from the database.
204223
205224 Args:
206225 subject_id: ID of the subject the session belongs to
207226 session_id: ID of the session to delete
208227 on_conflict: Behavior when session doesn't exist ('error' or 'skip')
209228 """
229+ on_conflict = _normalize_on_conflict (on_conflict )
210230 # Check if the session ID exists in the database for this subject
211231 session_ids = self .get_session_ids (subject_id )
212232
@@ -229,7 +249,7 @@ def delete_session(self, subject_id: str, session_id: str, on_conflict: OnConfli
229249
230250 self .logger .info (f"Removed session with ID { session_id } from the database." )
231251
232- def write_run (self , run :Run , session :Session = None , protocol :Protocol = None , on_conflict = OnConflictOpts .ERROR ):
252+ def write_run (self , run :Run , session :Session = None , protocol :Protocol = None , on_conflict : OnConflictOpts | str = OnConflictOpts .ERROR ):
233253 """Write a run with a snapshot of session and a snapshot of protocol if provided
234254
235255 Args:
@@ -240,6 +260,7 @@ def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, o
240260 Returns:
241261 None: This method does not return a value
242262 """
263+ on_conflict = _normalize_on_conflict (on_conflict )
243264 # Check whether the run already exist in the session
244265 run_ids = self .get_run_ids (session .subject_id , session .id )
245266
@@ -270,7 +291,8 @@ def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, o
270291 # Write snapshot of the protocol
271292 protocol .to_file (run_metadata_filepath .parent / f'{ run .id } _protocol_snapshot.json' )
272293
273- def write_subject (self , subject , on_conflict = OnConflictOpts .ERROR ):
294+ def write_subject (self , subject , on_conflict : OnConflictOpts | str = OnConflictOpts .ERROR ):
295+ on_conflict = _normalize_on_conflict (on_conflict )
274296 subject_id = subject .id
275297 subject_ids = self .get_subject_ids ()
276298
@@ -305,7 +327,7 @@ def write_transducer(
305327 transducer ,
306328 registration_surface_model_filepath : PathLike | None = None ,
307329 transducer_body_model_filepath : PathLike | None = None ,
308- on_conflict : OnConflictOpts = OnConflictOpts .ERROR ,
330+ on_conflict : OnConflictOpts | str = OnConflictOpts .ERROR ,
309331 ) -> None :
310332 """ Writes a transducer object to database and copies the affiliated transducer data files to the database if provided. When a transducer that is already present in the database is being re-written,
311333 the associated model data files do not need to be provided if they have previously been added to the database.
@@ -316,6 +338,7 @@ def write_transducer(
316338 Returns:
317339 None: This method does not return a value
318340 """
341+ on_conflict = _normalize_on_conflict (on_conflict )
319342 transducer_id = transducer .id
320343 transducer_ids = self .get_transducer_ids ()
321344
@@ -365,7 +388,8 @@ def write_transducer(
365388
366389 self .logger .info (f"Added transducer with ID { transducer_id } to the database." )
367390
368- def write_volume (self , subject_id , volume_id , volume_name , volume_data_filepath , on_conflict = OnConflictOpts .ERROR ):
391+ def write_volume (self , subject_id , volume_id , volume_name , volume_data_filepath , on_conflict : OnConflictOpts | str = OnConflictOpts .ERROR ):
392+ on_conflict = _normalize_on_conflict (on_conflict )
369393 if not Path (volume_data_filepath ).exists ():
370394 raise ValueError (f'Volume data filepath does not exist: { volume_data_filepath } ' )
371395
@@ -421,10 +445,11 @@ def write_volume(self, subject_id, volume_id, volume_name, volume_data_filepath,
421445 if temp_nifti_path is not None and temp_nifti_path .exists ():
422446 temp_nifti_path .unlink ()
423447
424- def write_photocollection (self , subject_id , session_id , reference_number : str , photo_paths : List [PathLike ], on_conflict = OnConflictOpts .ERROR ):
448+ def write_photocollection (self , subject_id , session_id , reference_number : str , photo_paths : List [PathLike ], on_conflict : OnConflictOpts | str = OnConflictOpts .ERROR ):
425449 """ Writes a photocollection to database and copies the associated
426450 photos into the database, specified by the subject, session, and
427451 reference_number of the photocollection."""
452+ on_conflict = _normalize_on_conflict (on_conflict )
428453
429454 photocollection_dir = Path (self .get_session_dir (subject_id , session_id )) / 'photocollections' / reference_number
430455
@@ -457,10 +482,11 @@ def write_photocollection(self, subject_id, session_id, reference_number: str, p
457482
458483 self .logger .info (f"Added photocollection with reference number { reference_number } for session { session_id } to the database." )
459484
460- def write_photoscan (self , subject_id , session_id , photoscan : Photoscan , model_data_filepath : str | None = None , texture_data_filepath : str | None = None , mtl_data_filepath : str | None = None , on_conflict = OnConflictOpts .ERROR ):
485+ def write_photoscan (self , subject_id , session_id , photoscan : Photoscan , model_data_filepath : str | None = None , texture_data_filepath : str | None = None , mtl_data_filepath : str | None = None , on_conflict : OnConflictOpts | str = OnConflictOpts .ERROR ):
461486 """ Writes a photoscan object to database and copies the associated data filepaths into the database.
462487 While the model data file is required, the associated texture and .mtl files are optional and can be provided if present.
463488 When a photoscan that is already present in the database is being re-written,the associated data files do not need to be specified """
489+ on_conflict = _normalize_on_conflict (on_conflict )
464490
465491 photoscan_ids = self .get_photoscan_ids (subject_id , session_id )
466492 if photoscan .id in photoscan_ids :
@@ -518,7 +544,8 @@ def write_photoscan(self, subject_id, session_id, photoscan: Photoscan, model_da
518544
519545 self .logger .info (f"Added photoscan with ID { photoscan .id } for session { session_id } to the database." )
520546
521- def write_solution (self , session :Session , solution :Solution , on_conflict : OnConflictOpts = OnConflictOpts .ERROR ):
547+ def write_solution (self , session :Session , solution :Solution , on_conflict : OnConflictOpts | str = OnConflictOpts .ERROR ):
548+ on_conflict = _normalize_on_conflict (on_conflict )
522549 solution_ids = self .get_solution_ids (session .subject_id , session .id )
523550
524551 if solution .id in solution_ids :
0 commit comments