Skip to content

Commit d8a3c3b

Browse files
committed
feat: refactor models to use NamedTuple for ArcBuildData and RelatedDataBatch, enhancing performance and memory efficiency
1 parent 37d8c33 commit d8a3c3b

3 files changed

Lines changed: 30 additions & 22 deletions

File tree

middleware/sql_to_arc/src/middleware/sql_to_arc/models.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
import concurrent.futures
44
from datetime import datetime
5-
from typing import Any
5+
from typing import Any, NamedTuple
66

77
from pydantic import BaseModel, ConfigDict
88

@@ -83,7 +83,7 @@ class ContactRow(BaseModel):
8383
model_config = ConfigDict(extra="allow", coerce_numbers_to_str=True, from_attributes=True)
8484

8585

86-
class ArcBuildData(BaseModel):
86+
class ArcBuildData(NamedTuple):
8787
"""Data bundle for building a single ARC."""
8888

8989
investigation_row: InvestigationRow
@@ -93,8 +93,6 @@ class ArcBuildData(BaseModel):
9393
publications: list[PublicationRow]
9494
annotations: list[dict[str, Any]]
9595

96-
model_config = ConfigDict(arbitrary_types_allowed=True)
97-
9896

9997
class WorkerContext(BaseModel):
10098
"""Context data for a worker process."""
@@ -114,7 +112,7 @@ class WorkerContext(BaseModel):
114112
model_config = ConfigDict(arbitrary_types_allowed=True)
115113

116114

117-
class RelatedDataBatch(BaseModel):
115+
class RelatedDataBatch(NamedTuple):
118116
"""Batch of related data grouped by investigation ID."""
119117

120118
studies_by_inv: dict[str, list[StudyRow]]
@@ -124,5 +122,3 @@ class RelatedDataBatch(BaseModel):
124122
anns_by_inv: dict[str, list[dict[str, Any]]]
125123
study_count: int
126124
assay_count: int
127-
128-
model_config = ConfigDict(arbitrary_types_allowed=True)

middleware/sql_to_arc/src/middleware/sql_to_arc/processor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
import multiprocessing
88
from collections import defaultdict
99
from collections.abc import AsyncGenerator
10+
from dataclasses import dataclass
1011
from typing import Any, TypeVar
1112

1213
from opentelemetry import trace
13-
from pydantic import BaseModel, ConfigDict
14+
from pydantic import BaseModel
1415

1516
from middleware.api_client import ApiClient, ApiClientError
1617
from middleware.sql_to_arc.builder import build_single_arc_task
@@ -171,7 +172,8 @@ async def group_stream(
171172
)
172173

173174

174-
class WorkerResources(BaseModel):
175+
@dataclass(slots=True)
176+
class WorkerResources:
175177
"""Orchestration resources shared across investigation tasks."""
176178

177179
client: ApiClient
@@ -180,8 +182,6 @@ class WorkerResources(BaseModel):
180182
executor: concurrent.futures.Executor
181183
semaphore: asyncio.Semaphore
182184

183-
model_config = ConfigDict(arbitrary_types_allowed=True)
184-
185185

186186
def _spawn_investigation_task(
187187
investigation: InvestigationRow,

middleware/sql_to_arc/tests/unit/test_builder.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@
66
import pytest
77

88
from middleware.sql_to_arc.builder import build_single_arc_task
9-
from middleware.sql_to_arc.models import ArcBuildData
9+
from middleware.sql_to_arc.models import (
10+
ArcBuildData,
11+
AssayRow,
12+
ContactRow,
13+
InvestigationRow,
14+
PublicationRow,
15+
StudyRow,
16+
)
1017

1118

1219
@pytest.fixture
@@ -97,7 +104,12 @@ def sample_publications() -> list[dict[str, Any]]:
97104
def test_build_simple_arc(sample_investigation: dict[str, Any]) -> None:
98105
"""Test building a basic ARC structure from investigation data."""
99106
arc_data = ArcBuildData(
100-
investigation_row=sample_investigation, studies=[], assays=[], contacts=[], publications=[], annotations=[]
107+
investigation_row=InvestigationRow.model_validate(sample_investigation),
108+
studies=[],
109+
assays=[],
110+
contacts=[],
111+
publications=[],
112+
annotations=[],
101113
)
102114
arc_json = build_single_arc_task(arc_data)
103115
assert isinstance(arc_json, str)
@@ -115,9 +127,9 @@ def test_build_arc_with_study_and_assay(
115127
) -> None:
116128
"""Test building an ARC with nested study and assay structures."""
117129
arc_data = ArcBuildData(
118-
investigation_row=sample_investigation,
119-
studies=sample_studies,
120-
assays=sample_assays,
130+
investigation_row=InvestigationRow.model_validate(sample_investigation),
131+
studies=[StudyRow.model_validate(s) for s in sample_studies],
132+
assays=[AssayRow.model_validate(a) for a in sample_assays],
121133
contacts=[],
122134
publications=[],
123135
annotations=[],
@@ -142,11 +154,11 @@ def test_build_arc_with_contacts_and_pubs(
142154
) -> None:
143155
"""Test building an ARC with contacts and publications at both investigation and study levels."""
144156
arc_data = ArcBuildData(
145-
investigation_row=sample_investigation,
146-
studies=sample_studies,
157+
investigation_row=InvestigationRow.model_validate(sample_investigation),
158+
studies=[StudyRow.model_validate(s) for s in sample_studies],
147159
assays=[],
148-
contacts=sample_contacts,
149-
publications=sample_publications,
160+
contacts=[ContactRow.model_validate(c) for c in sample_contacts],
161+
publications=[PublicationRow.model_validate(p) for p in sample_publications],
150162
annotations=[],
151163
)
152164
arc_json = build_single_arc_task(arc_data)
@@ -167,8 +179,8 @@ def test_build_ignores_irrelevant_data(sample_investigation: dict[str, Any]) ->
167179
other_study = {"identifier": "styX", "investigation_ref": "inv2"}
168180

169181
arc_data = ArcBuildData(
170-
investigation_row=sample_investigation,
171-
studies=[other_study],
182+
investigation_row=InvestigationRow.model_validate(sample_investigation),
183+
studies=[StudyRow.model_validate(other_study)],
172184
assays=[],
173185
contacts=[],
174186
publications=[],

0 commit comments

Comments
 (0)