|
1 | 1 | """Database configuration.""" |
2 | 2 |
|
| 3 | +import json |
3 | 4 | import uuid |
4 | 5 | import datetime |
| 6 | +from pathlib import Path |
5 | 7 | from typing import Set, List, Any |
6 | 8 | from contextvars import ContextVar |
7 | 9 | import enum |
@@ -61,6 +63,49 @@ class Base(DeclarativeBase): |
61 | 63 | DATETIME_TESTING = datetime.datetime(2024, 12, 26, 19, 37, 59, 753357) |
62 | 64 |
|
63 | 65 |
|
| 66 | +def _setup_test_institutions(session: Session) -> None: |
| 67 | + """Load optional local institution display data from config/local_inst_data.json (gitignored).""" |
| 68 | + file = Path("config/local_inst_data.json") |
| 69 | + if file.exists(): |
| 70 | + with open(file) as f: |
| 71 | + for inst in json.load(f): |
| 72 | + session.merge( |
| 73 | + InstTable( |
| 74 | + id=uuid.UUID(inst["inst_id"]), |
| 75 | + name=inst["name"], |
| 76 | + state=inst.get("state"), |
| 77 | + pdp_id=inst.get("pdp_id"), |
| 78 | + created_at=DATETIME_TESTING, |
| 79 | + updated_at=DATETIME_TESTING, |
| 80 | + created_by=LOCAL_USER_UUID, |
| 81 | + ) |
| 82 | + ) |
| 83 | + schemas_by_file_id = { |
| 84 | + f["data_id"]: f.get("schemas", []) for f in inst.get("files", []) |
| 85 | + } |
| 86 | + for batch in inst.get("batches", []): |
| 87 | + batch_table = BatchTable( |
| 88 | + id=uuid.UUID(batch["batch_id"]), |
| 89 | + inst_id=uuid.UUID(inst["inst_id"]), |
| 90 | + name=batch["name"], |
| 91 | + created_at=DATETIME_TESTING, |
| 92 | + updated_at=DATETIME_TESTING, |
| 93 | + created_by=LOCAL_USER_UUID, |
| 94 | + ) |
| 95 | + for file_name, file_id in batch["file_names_to_ids"].items(): |
| 96 | + batch_table.files.add( |
| 97 | + session.merge( |
| 98 | + FileTable( |
| 99 | + id=uuid.UUID(file_id), |
| 100 | + inst_id=uuid.UUID(inst["inst_id"]), |
| 101 | + name=file_name, |
| 102 | + schemas=schemas_by_file_id.get(file_id, []), |
| 103 | + ) |
| 104 | + ) # type: ignore |
| 105 | + ) |
| 106 | + session.merge(batch_table) |
| 107 | + |
| 108 | + |
64 | 109 | @event.listens_for(Mapper, "before_insert") |
65 | 110 | @event.listens_for(Mapper, "before_update") |
66 | 111 | def validate_string_lengths(mapper, connection, target): |
@@ -121,6 +166,7 @@ def init_db(env: str) -> None: |
121 | 166 | ) |
122 | 167 | # Create test files and batches for LOCAL environment |
123 | 168 | if env == "LOCAL": |
| 169 | + _setup_test_institutions(session) |
124 | 170 | # Create test files |
125 | 171 | test_file_1 = FileTable( |
126 | 172 | id=uuid.UUID("f0bb3a20-6d92-4254-afed-6a72f43c562a"), |
|
0 commit comments