Skip to content

Commit e77f5e6

Browse files
committed
feat: refactor data models and mapping functions to use Pydantic for improved type safety and validation
1 parent 70a78fe commit e77f5e6

9 files changed

Lines changed: 335 additions & 173 deletions

File tree

middleware/sql_to_arc/src/middleware/sql_to_arc/builder.py

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from arctrl import ( # type: ignore[import-untyped]
1010
ARC,
11+
ArcAssay,
12+
ArcStudy,
1113
ArcTable,
1214
CompositeCell,
1315
CompositeHeader,
@@ -22,57 +24,74 @@
2224
map_publication,
2325
map_study,
2426
)
25-
from middleware.sql_to_arc.models import ArcBuildData
27+
from middleware.sql_to_arc.models import (
28+
ArcBuildData,
29+
AssayRow,
30+
ContactRow,
31+
PublicationRow,
32+
StudyRow,
33+
)
2634

2735
logger = logging.getLogger(__name__)
2836

2937

30-
def _add_studies_to_arc(arc: ARC, study_rows: list[dict[str, Any]]) -> dict[str, Any]:
38+
def _add_studies_to_arc(arc: ARC, study_rows: list[StudyRow]) -> dict[str, ArcStudy]:
3139
"""Add studies to ARC and return study map."""
32-
study_map = {}
40+
study_map: dict[str, ArcStudy] = {}
3341
for s_row in study_rows:
3442
study = map_study(s_row)
3543
arc.AddRegisteredStudy(study)
36-
study_map[str(s_row["identifier"])] = study
44+
study_map[str(s_row.identifier)] = study
3745
return study_map
3846

3947

40-
def _add_assays_to_arc(arc: ARC, assay_rows: list[dict[str, Any]], study_map: dict[str, Any]) -> dict[str, Any]:
48+
def _add_assays_to_arc(arc: ARC, assay_rows: list[AssayRow], study_map: dict[str, ArcStudy]) -> dict[str, ArcAssay]:
4149
"""Add assays to ARC, link to studies, and return assay map."""
42-
assay_map = {}
50+
assay_map: dict[str, ArcAssay] = {}
4351
for a_row in assay_rows:
4452
assay = map_assay(a_row)
4553
arc.AddAssay(assay)
46-
assay_map[str(a_row["identifier"])] = assay
54+
assay_map[str(a_row.identifier)] = assay
4755

4856
# Link Assay to Studies
49-
study_ref_json = a_row.get("study_ref")
50-
if not study_ref_json:
51-
continue
57+
_link_assay_to_studies(assay, a_row.study_ref, study_map)
58+
59+
return assay_map
60+
5261

62+
def _link_assay_to_studies(assay: ArcAssay, study_ref_val: Any, study_map: dict[str, ArcStudy]) -> None:
63+
"""Link an assay to one or more studies based on the study_ref value."""
64+
if not study_ref_val:
65+
return
66+
67+
if isinstance(study_ref_val, str):
5368
try:
54-
study_refs = json.loads(study_ref_json)
69+
study_refs = json.loads(study_ref_val)
5570
if isinstance(study_refs, list):
5671
for s_ref in study_refs:
57-
if s_ref in study_map:
58-
study_map[s_ref].RegisterAssay(assay.Identifier)
72+
if str(s_ref) in study_map:
73+
study_map[str(s_ref)].RegisterAssay(assay.Identifier)
74+
return
5975
except json.JSONDecodeError:
76+
# Handle single ID if it's not JSON (fall through)
6077
pass
6178

62-
return assay_map
79+
# Handle single ID (string or int)
80+
if str(study_ref_val) in study_map:
81+
study_map[str(study_ref_val)].RegisterAssay(assay.Identifier)
6382

6483

6584
def _add_contacts_to_arc(
6685
arc: ARC,
6786
inv_id: str,
68-
contacts: list[dict[str, Any]],
69-
study_map: dict[str, Any],
70-
assay_map: dict[str, Any],
87+
contacts: list[ContactRow],
88+
study_map: dict[str, ArcStudy],
89+
assay_map: dict[str, ArcAssay],
7190
) -> None:
7291
"""Add contacts to investigation, studies, and assays."""
7392
# Investigation contacts
7493
inv_contacts = [
75-
c for c in contacts if c.get("investigation_ref") == inv_id and c.get("target_type") == "investigation"
94+
c for c in contacts if str(c.investigation_ref) == inv_id and getattr(c, "target_type", None) == "investigation"
7695
]
7796
for c_row in inv_contacts:
7897
arc.Contacts.append(map_contact(c_row))
@@ -82,7 +101,9 @@ def _add_contacts_to_arc(
82101
stu_contacts = [
83102
c
84103
for c in contacts
85-
if c.get("investigation_ref") == inv_id and c.get("target_type") == "study" and c.get("target_ref") == s_id
104+
if str(c.investigation_ref) == inv_id
105+
and getattr(c, "target_type", None) == "study"
106+
and str(getattr(c, "target_ref", None)) == s_id
86107
]
87108
for c_row in stu_contacts:
88109
study.Contacts.append(map_contact(c_row))
@@ -92,19 +113,26 @@ def _add_contacts_to_arc(
92113
ass_contacts = [
93114
c
94115
for c in contacts
95-
if c.get("investigation_ref") == inv_id and c.get("target_type") == "assay" and c.get("target_ref") == a_id
116+
if str(c.investigation_ref) == inv_id
117+
and getattr(c, "target_type", None) == "assay"
118+
and str(getattr(c, "target_ref", None)) == a_id
96119
]
97120
for c_row in ass_contacts:
98121
assay.Performers.append(map_contact(c_row))
99122

100123

101124
def _add_publications_to_arc(
102-
arc: ARC, inv_id: str, publications: list[dict[str, Any]], study_map: dict[str, Any]
125+
arc: ARC,
126+
inv_id: str,
127+
publications: list[PublicationRow],
128+
study_map: dict[str, ArcStudy],
103129
) -> None:
104130
"""Add publications to investigation and studies."""
105131
# Investigation publications
106132
inv_pubs = [
107-
p for p in publications if p.get("investigation_ref") == inv_id and p.get("target_type") == "investigation"
133+
p
134+
for p in publications
135+
if str(p.investigation_ref) == inv_id and getattr(p, "target_type", None) == "investigation"
108136
]
109137
for p_row in inv_pubs:
110138
arc.Publications.append(map_publication(p_row))
@@ -114,7 +142,9 @@ def _add_publications_to_arc(
114142
stu_pubs = [
115143
p
116144
for p in publications
117-
if p.get("investigation_ref") == inv_id and p.get("target_type") == "study" and p.get("target_ref") == s_id
145+
if str(p.investigation_ref) == inv_id
146+
and getattr(p, "target_type", None) == "study"
147+
and str(getattr(p, "target_ref", None)) == s_id
118148
]
119149
for p_row in stu_pubs:
120150
study.Publications.append(map_publication(p_row))
@@ -271,16 +301,16 @@ def build_single_arc_task(data: ArcBuildData) -> str:
271301
This function is designed to run in a separate process.
272302
It returns the JSON representation to minimize memory footprint in the main process.
273303
"""
274-
inv_id = str(data.investigation_row["identifier"])
304+
inv_id = str(data.investigation_row.identifier)
275305

276306
try:
277307
# Map Investigation and create ARC
278308
arc_inv = map_investigation(data.investigation_row)
279309
arc = ARC.from_arc_investigation(arc_inv)
280310

281311
# Identify relevant studies and assays
282-
relevant_studies = [s for s in data.studies if s.get("investigation_ref") == inv_id]
283-
relevant_assays = [a for a in data.assays if a.get("investigation_ref") == inv_id]
312+
relevant_studies = [s for s in data.studies if str(s.investigation_ref) == inv_id]
313+
relevant_assays = [a for a in data.assays if str(a.investigation_ref) == inv_id]
284314

285315
# Add studies and assays
286316
study_map = _add_studies_to_arc(arc, relevant_studies)

middleware/sql_to_arc/src/middleware/sql_to_arc/database.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@
1515
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
1616
from sqlalchemy.sql import select
1717

18+
from middleware.sql_to_arc.models import (
19+
AssayRow,
20+
ContactRow,
21+
InvestigationRow,
22+
PublicationRow,
23+
StudyRow,
24+
)
25+
1826
# Define metadata
1927
metadata = MetaData()
2028

@@ -126,55 +134,55 @@ def __init__(self, connection_string: str) -> None:
126134
"""Initialize database with connection string."""
127135
self.engine: AsyncEngine = create_async_engine(connection_string, echo=False)
128136

129-
async def stream_investigations(self, limit: int | None = None) -> AsyncGenerator[dict[str, Any], None]:
137+
async def stream_investigations(self, limit: int | None = None) -> AsyncGenerator[InvestigationRow, None]:
130138
"""Stream investigations using a server-side cursor."""
131139
async with self.engine.connect() as conn:
132140
stmt = select(v_investigation)
133141
if limit:
134142
stmt = stmt.limit(limit)
135143
result = await conn.stream(stmt.execution_options(stream_results=True))
136144
async for row in result.mappings():
137-
yield dict(row)
145+
yield InvestigationRow.model_validate(row)
138146

139-
async def stream_studies(self, investigation_ids: list[str]) -> AsyncGenerator[dict[str, Any], None]:
147+
async def stream_studies(self, investigation_ids: list[str]) -> AsyncGenerator[StudyRow, None]:
140148
"""Stream studies for given investigations."""
141149
if not investigation_ids:
142150
return
143151
async with self.engine.connect() as conn:
144152
stmt = select(v_study).where(v_study.c.investigation_ref.in_(investigation_ids))
145153
result = await conn.stream(stmt.execution_options(stream_results=True))
146154
async for row in result.mappings():
147-
yield dict(row)
155+
yield StudyRow.model_validate(row)
148156

149-
async def stream_assays(self, investigation_ids: list[str]) -> AsyncGenerator[dict[str, Any], None]:
157+
async def stream_assays(self, investigation_ids: list[str]) -> AsyncGenerator[AssayRow, None]:
150158
"""Stream assays for given investigations."""
151159
if not investigation_ids:
152160
return
153161
async with self.engine.connect() as conn:
154162
stmt = select(v_assay).where(v_assay.c.investigation_ref.in_(investigation_ids))
155163
result = await conn.stream(stmt.execution_options(stream_results=True))
156164
async for row in result.mappings():
157-
yield dict(row)
165+
yield AssayRow.model_validate(row)
158166

159-
async def stream_contacts(self, investigation_ids: list[str]) -> AsyncGenerator[dict[str, Any], None]:
167+
async def stream_contacts(self, investigation_ids: list[str]) -> AsyncGenerator[ContactRow, None]:
160168
"""Stream contacts for given investigations."""
161169
if not investigation_ids:
162170
return
163171
async with self.engine.connect() as conn:
164172
stmt = select(v_contact).where(v_contact.c.investigation_ref.in_(investigation_ids))
165173
result = await conn.stream(stmt.execution_options(stream_results=True))
166174
async for row in result.mappings():
167-
yield dict(row)
175+
yield ContactRow.model_validate(row)
168176

169-
async def stream_publications(self, investigation_ids: list[str]) -> AsyncGenerator[dict[str, Any], None]:
177+
async def stream_publications(self, investigation_ids: list[str]) -> AsyncGenerator[PublicationRow, None]:
170178
"""Stream publications for given investigations."""
171179
if not investigation_ids:
172180
return
173181
async with self.engine.connect() as conn:
174182
stmt = select(v_publication).where(v_publication.c.investigation_ref.in_(investigation_ids))
175183
result = await conn.stream(stmt.execution_options(stream_results=True))
176184
async for row in result.mappings():
177-
yield dict(row)
185+
yield PublicationRow.model_validate(row)
178186

179187
async def stream_annotation_tables(self, investigation_ids: list[str]) -> AsyncGenerator[dict[str, Any], None]:
180188
"""Stream annotation tables for given investigations."""

0 commit comments

Comments
 (0)