Skip to content

Commit 2f9e682

Browse files
Merge branch 'main' into p/init_install_cleanup
2 parents fb77c32 + 68b00c9 commit 2f9e682

2 files changed

Lines changed: 79 additions & 13 deletions

File tree

src/openlifu/db/database.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,22 @@
2424
from .subject import Subject
2525
from .user import User
2626

27-
OnConflictOpts = Enum('OnConflictOpts', ['ERROR', 'OVERWRITE', 'SKIP'])
27+
28+
class OnConflictOpts(str, Enum):
29+
ERROR = "error"
30+
OVERWRITE = "overwrite"
31+
SKIP = "skip"
32+
33+
34+
def _normalize_on_conflict(on_conflict: OnConflictOpts | str) -> OnConflictOpts:
35+
if isinstance(on_conflict, OnConflictOpts):
36+
return on_conflict
37+
if isinstance(on_conflict, str):
38+
try:
39+
return OnConflictOpts(on_conflict.lower())
40+
except ValueError as exc:
41+
raise ValueError("Invalid 'on_conflict' option. Use 'error', 'overwrite', or 'skip'.") from exc
42+
raise ValueError("Invalid 'on_conflict' option. Use 'error', 'overwrite', or 'skip'.")
2843

2944

3045
class Database:
@@ -59,7 +74,8 @@ def write_user(self, user: User, on_conflict: OnConflictOpts = OnConflictOpts.ER
5974

6075
self.logger.info(f"Added User with ID {user_id} to the database.")
6176

62-
def delete_user(self, user_id: str, on_conflict: OnConflictOpts = OnConflictOpts.ERROR) -> None:
77+
def delete_user(self, user_id: str, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR) -> None:
78+
on_conflict = _normalize_on_conflict(on_conflict)
6379
# Check if the user ID already exists in the database
6480
user_ids = self.get_user_ids()
6581

@@ -83,7 +99,8 @@ def delete_user(self, user_id: str, on_conflict: OnConflictOpts = OnConflictOpts
8399

84100
self.logger.info(f"Removed Sonication User with ID {user_id} from the database.")
85101

86-
def write_protocol(self, protocol: Protocol, on_conflict: OnConflictOpts = OnConflictOpts.ERROR):
102+
def write_protocol(self, protocol: Protocol, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
103+
on_conflict = _normalize_on_conflict(on_conflict)
87104
# Check if the sonication protocol ID already exists in the database
88105
protocol_id = protocol.id
89106
protocol_ids = self.get_protocol_ids()
@@ -110,7 +127,8 @@ def write_protocol(self, protocol: Protocol, on_conflict: OnConflictOpts = OnCon
110127

111128
self.logger.info(f"Added Sonication Protocol with ID {protocol_id} to the database.")
112129

113-
def delete_protocol(self, protocol_id: str, on_conflict: OnConflictOpts = OnConflictOpts.ERROR):
130+
def delete_protocol(self, protocol_id: str, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
131+
on_conflict = _normalize_on_conflict(on_conflict)
114132
# Check if the sonication protocol ID already exists in the database
115133
protocol_ids = self.get_protocol_ids()
116134

@@ -134,7 +152,8 @@ def delete_protocol(self, protocol_id: str, on_conflict: OnConflictOpts = OnConf
134152

135153
self.logger.info(f"Removed Sonication Protocol with ID {protocol_id} from the database.")
136154

137-
def write_session(self, subject:Subject, session:Session, on_conflict=OnConflictOpts.ERROR):
155+
def write_session(self, subject:Subject, session:Session, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
156+
on_conflict = _normalize_on_conflict(on_conflict)
138157
# Generate session ID
139158
session_id = session.id
140159

@@ -199,14 +218,15 @@ def write_session(self, subject:Subject, session:Session, on_conflict=OnConflict
199218

200219
self.logger.info(f"Added session with ID {session_id} for subject {subject.id} to the database.")
201220

202-
def delete_session(self, subject_id: str, session_id: str, on_conflict: OnConflictOpts = OnConflictOpts.ERROR):
221+
def delete_session(self, subject_id: str, session_id: str, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
203222
"""Delete a session and its associated data from the database.
204223
205224
Args:
206225
subject_id: ID of the subject the session belongs to
207226
session_id: ID of the session to delete
208227
on_conflict: Behavior when session doesn't exist ('error' or 'skip')
209228
"""
229+
on_conflict = _normalize_on_conflict(on_conflict)
210230
# Check if the session ID exists in the database for this subject
211231
session_ids = self.get_session_ids(subject_id)
212232

@@ -229,7 +249,7 @@ def delete_session(self, subject_id: str, session_id: str, on_conflict: OnConfli
229249

230250
self.logger.info(f"Removed session with ID {session_id} from the database.")
231251

232-
def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, on_conflict=OnConflictOpts.ERROR):
252+
def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
233253
"""Write a run with a snapshot of session and a snapshot of protocol if provided
234254
235255
Args:
@@ -240,6 +260,7 @@ def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, o
240260
Returns:
241261
None: This method does not return a value
242262
"""
263+
on_conflict = _normalize_on_conflict(on_conflict)
243264
# Check whether the run already exist in the session
244265
run_ids = self.get_run_ids(session.subject_id, session.id)
245266

@@ -270,7 +291,8 @@ def write_run(self, run:Run, session:Session = None, protocol:Protocol = None, o
270291
# Write snapshot of the protocol
271292
protocol.to_file(run_metadata_filepath.parent / f'{run.id}_protocol_snapshot.json')
272293

273-
def write_subject(self, subject, on_conflict=OnConflictOpts.ERROR):
294+
def write_subject(self, subject, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
295+
on_conflict = _normalize_on_conflict(on_conflict)
274296
subject_id = subject.id
275297
subject_ids = self.get_subject_ids()
276298

@@ -305,7 +327,7 @@ def write_transducer(
305327
transducer,
306328
registration_surface_model_filepath: PathLike | None = None,
307329
transducer_body_model_filepath: PathLike | None = None,
308-
on_conflict: OnConflictOpts=OnConflictOpts.ERROR,
330+
on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR,
309331
) -> None:
310332
""" Writes a transducer object to database and copies the affiliated transducer data files to the database if provided. When a transducer that is already present in the database is being re-written,
311333
the associated model data files do not need to be provided if they have previously been added to the database.
@@ -316,6 +338,7 @@ def write_transducer(
316338
Returns:
317339
None: This method does not return a value
318340
"""
341+
on_conflict = _normalize_on_conflict(on_conflict)
319342
transducer_id = transducer.id
320343
transducer_ids = self.get_transducer_ids()
321344

@@ -365,7 +388,8 @@ def write_transducer(
365388

366389
self.logger.info(f"Added transducer with ID {transducer_id} to the database.")
367390

368-
def write_volume(self, subject_id, volume_id, volume_name, volume_data_filepath, on_conflict=OnConflictOpts.ERROR):
391+
def write_volume(self, subject_id, volume_id, volume_name, volume_data_filepath, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
392+
on_conflict = _normalize_on_conflict(on_conflict)
369393
if not Path(volume_data_filepath).exists():
370394
raise ValueError(f'Volume data filepath does not exist: {volume_data_filepath}')
371395

@@ -421,10 +445,11 @@ def write_volume(self, subject_id, volume_id, volume_name, volume_data_filepath,
421445
if temp_nifti_path is not None and temp_nifti_path.exists():
422446
temp_nifti_path.unlink()
423447

424-
def write_photocollection(self, subject_id, session_id, reference_number: str, photo_paths: List[PathLike], on_conflict=OnConflictOpts.ERROR):
448+
def write_photocollection(self, subject_id, session_id, reference_number: str, photo_paths: List[PathLike], on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
425449
""" Writes a photocollection to database and copies the associated
426450
photos into the database, specified by the subject, session, and
427451
reference_number of the photocollection."""
452+
on_conflict = _normalize_on_conflict(on_conflict)
428453

429454
photocollection_dir = Path(self.get_session_dir(subject_id, session_id)) / 'photocollections' / reference_number
430455

@@ -457,10 +482,11 @@ def write_photocollection(self, subject_id, session_id, reference_number: str, p
457482

458483
self.logger.info(f"Added photocollection with reference number {reference_number} for session {session_id} to the database.")
459484

460-
def write_photoscan(self, subject_id, session_id, photoscan: Photoscan, model_data_filepath: str | None = None, texture_data_filepath: str | None = None, mtl_data_filepath: str | None = None, on_conflict=OnConflictOpts.ERROR):
485+
def write_photoscan(self, subject_id, session_id, photoscan: Photoscan, model_data_filepath: str | None = None, texture_data_filepath: str | None = None, mtl_data_filepath: str | None = None, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
461486
""" Writes a photoscan object to database and copies the associated data filepaths into the database.
462487
While the model data file is required, the associated texture and .mtl files are optional and can be provided if present.
463488
When a photoscan that is already present in the database is being re-written,the associated data files do not need to be specified """
489+
on_conflict = _normalize_on_conflict(on_conflict)
464490

465491
photoscan_ids = self.get_photoscan_ids(subject_id, session_id)
466492
if photoscan.id in photoscan_ids:
@@ -518,7 +544,8 @@ def write_photoscan(self, subject_id, session_id, photoscan: Photoscan, model_da
518544

519545
self.logger.info(f"Added photoscan with ID {photoscan.id} for session {session_id} to the database.")
520546

521-
def write_solution(self, session:Session, solution:Solution, on_conflict: OnConflictOpts=OnConflictOpts.ERROR):
547+
def write_solution(self, session:Session, solution:Solution, on_conflict: OnConflictOpts | str = OnConflictOpts.ERROR):
548+
on_conflict = _normalize_on_conflict(on_conflict)
522549
solution_ids = self.get_solution_ids(session.subject_id, session.id)
523550

524551
if solution.id in solution_ids:

tests/test_database.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,45 @@ def test_write_protocol(example_database: Database):
9393
reloaded_protocol = example_database.load_protocol(protocol.id)
9494
assert reloaded_protocol.name == "new_name"
9595

96+
97+
def test_on_conflict_accepts_enum_and_strings(example_database: Database):
98+
assert OnConflictOpts.OVERWRITE.value == "overwrite"
99+
100+
protocol = Protocol(name="bleh", id="a_protocol_with_string_conflict_option")
101+
example_database.write_protocol(protocol)
102+
103+
with pytest.raises(ValueError, match="already exists"):
104+
example_database.write_protocol(protocol, on_conflict="ERROR")
105+
106+
protocol.name = "skipped_name"
107+
example_database.write_protocol(protocol, on_conflict="SkIp")
108+
reloaded_protocol = example_database.load_protocol(protocol.id)
109+
assert reloaded_protocol.name == "bleh"
110+
111+
protocol.name = "overwritten_name"
112+
example_database.write_protocol(protocol, on_conflict="OVERWRITE")
113+
reloaded_protocol = example_database.load_protocol(protocol.id)
114+
assert reloaded_protocol.name == "overwritten_name"
115+
116+
example_database.delete_protocol("non_existent_protocol", on_conflict="skip")
117+
118+
user = User(name="initial_name", id="a_user_with_string_conflict_option")
119+
example_database.write_user(user)
120+
121+
user.name = "skipped_name"
122+
example_database.write_user(user, on_conflict="skip")
123+
reloaded_user = example_database.load_user(user.id)
124+
assert reloaded_user.name == "initial_name"
125+
126+
user.name = "overwritten_name"
127+
example_database.write_user(user, on_conflict="overwrite")
128+
reloaded_user = example_database.load_user(user.id)
129+
assert reloaded_user.name == "overwritten_name"
130+
131+
with pytest.raises(ValueError, match="Invalid 'on_conflict' option"):
132+
example_database.write_protocol(protocol, on_conflict="replace")
133+
134+
96135
def test_delete_protocol(example_database: Database):
97136
# Write a protocol
98137
protocol = Protocol(name="bleh", id="a_protocol_to_be_deleted")

0 commit comments

Comments
 (0)