Skip to content

Commit 5806136

Browse files
authored
fix: properly handle None dependency_graph during merge (#1657)
* fix: properly handle None dependency_graph during merge
* fix: raise error for one None one not
* fix: change merge code to work when one graph exists and the other doesn't
* chore: refactor to simplify __add__ function
* test: ensure link occurred in process merge
* chore: fix dependencies due to new adsm
1 parent 00e604b commit 5806136

5 files changed

Lines changed: 156 additions & 13 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ readme = "README.md"
1414
dynamic = ["version"]
1515

1616
dependencies = [
17-
'aind-data-schema-models>=4.2.11',
17+
'aind-data-schema-models>=4.2.11,<5',
1818
'pydantic>=2.7, <2.12',
1919
'semver'
2020
]

src/aind_data_schema/core/processing.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from aind_data_schema.base import AwareDatetimeWithDefault, DataCoreModel, DataModel, GenericModel
1313
from aind_data_schema.components.identifiers import Code
1414
from aind_data_schema.components.wrappers import AssetPath
15-
from aind_data_schema.utils.merge import merge_notes, merge_optional_list
15+
from aind_data_schema.utils.merge import merge_notes, merge_optional_list, merge_process_graph
1616
from aind_data_schema.utils.validators import TimeValidation
1717

1818

@@ -135,12 +135,14 @@ def rename_process(self, old_name: str, new_name: str) -> None:
135135
break
136136
else:
137137
raise ValueError(f"Process '{old_name}' not found in data_processes.")
138+
138139
# rename in dependency_graph
139-
self.dependency_graph[new_name] = self.dependency_graph.pop(old_name)
140-
# replace old_name in dependency_graph values
141-
for value in self.dependency_graph.values():
142-
if old_name in value:
143-
value[value.index(old_name)] = new_name
140+
if self.dependency_graph:
141+
self.dependency_graph[new_name] = self.dependency_graph.pop(old_name)
142+
# replace old_name in dependency_graph values
143+
for value in self.dependency_graph.values():
144+
if old_name in value:
145+
value[value.index(old_name)] = new_name
144146

145147
@model_validator(mode="after")
146148
def order_processes(self) -> "Processing":
@@ -248,19 +250,19 @@ def __add__(self, other: "Processing") -> "Processing":
248250
i += 1
249251
other.rename_process(name, new_name)
250252

251-
# Merge process graphs - start with self's graph and update with other's graph
252-
combined_process_graph = self.dependency_graph.copy()
253-
combined_process_graph.update(other.dependency_graph)
253+
merged_graph = merge_process_graph(
254+
self.dependency_graph, other.dependency_graph, self.data_processes, other.data_processes
255+
)
254256

255257
# link self's output to other's input
256258
# note that this only makes sense if self has a single output process
257259
# and other has a single input process
258-
if len(self.data_processes) > 0 and len(other.data_processes) > 0:
259-
combined_process_graph[other.data_processes[0].name] = [self.data_processes[-1].name]
260+
if merged_graph and len(self.data_processes) > 0 and len(other.data_processes) > 0:
261+
merged_graph[other.data_processes[0].name] = [self.data_processes[-1].name]
260262

261263
return Processing(
262264
pipelines=merge_optional_list(self.pipelines, other.pipelines),
263265
data_processes=self.data_processes + other.data_processes,
264-
dependency_graph=combined_process_graph,
266+
dependency_graph=merged_graph,
265267
notes=merge_notes(self.notes, other.notes),
266268
)

src/aind_data_schema/utils/merge.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,34 @@
44
from typing import Any, List, Optional
55

66

7+
def merge_process_graph(
    graph1: Optional[dict],
    graph2: Optional[dict],
    processes1: List[Any],
    processes2: List[Any],
) -> Optional[dict]:
    """Merge two process dependency graphs into a new, independent graph.

    A graph that is ``None`` or empty is treated as absent.

    Parameters
    ----------
    graph1 : Optional[dict]
        First dependency graph; its entries come first in the result.
    graph2 : Optional[dict]
        Second dependency graph; on a key collision its entry wins.
    processes1 : List[Any]
        Objects exposing a ``name`` attribute. Only consulted when ``graph1``
        is absent and ``graph2`` is present: each name gets an empty entry.
    processes2 : List[Any]
        Objects exposing a ``name`` attribute. Only consulted when ``graph2``
        is absent and ``graph1`` is present: each name gets an empty entry.

    Returns
    -------
    Optional[dict]
        The merged graph, or ``None`` when both graphs are absent. List values
        are copied, so mutating the result never alters ``graph1``/``graph2``.
    """
    if not graph1 and not graph2:
        return None

    merged_graph: dict = {}
    for graph in (graph1, graph2):
        if not graph:
            continue
        for name, upstream in graph.items():
            # Copy list values: a plain dict.copy() would share them with the
            # input graph, and in-place edits of the merged graph would then
            # silently rewrite the source graph as well.
            merged_graph[name] = list(upstream) if isinstance(upstream, list) else upstream

    # Exactly one side may lack a graph at this point; give that side's
    # processes an entry with no dependencies so every process is represented.
    if not graph2:
        for process in processes2:
            merged_graph[process.name] = []
    elif not graph1:
        for process in processes1:
            merged_graph[process.name] = []

    return merged_graph
33+
34+
735
def merge_str_tuple_lists(
836
a: list[str | tuple[str, ...]], b: list[str | tuple[str, ...]]
937
) -> list[str | tuple[str, ...]]:

tests/test_composability_merge.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,85 @@ def test_add_processing_objects(self):
311311
self.assertEqual(combined.data_processes[2].name, "Analysis_2")
312312
self.assertEqual(combined.data_processes[3].name, "Analysis_3")
313313

314+
def test_merge_dependency_graph(self):
    """Check how dependency graphs combine when Processing objects are added"""

    timestamp = datetime(2022, 11, 22, 8, 43, 0, tzinfo=timezone.utc)

    def build(experimenter, process_type, graph):
        """Create a Processing holding one DataProcess and the given graph"""
        process = DataProcess(
            experimenters=[experimenter],
            process_type=process_type,
            stage=ProcessStage.PROCESSING,
            start_date_time=timestamp,
            code=Code(url="https://example.com", version="1.0"),
        )
        return Processing(data_processes=[process], dependency_graph=graph)

    # Neither side carries a graph -> the sum carries none either
    analysis_bare = build("Dr. Dan", ProcessName.ANALYSIS, None)
    compression_bare = build("Dr. Jane", ProcessName.COMPRESSION, None)

    total = analysis_bare + compression_bare
    self.assertIsNone(total.dependency_graph)

    # Both sides carry a graph -> entries are combined and the two are linked
    analysis_graphed = build("Dr. Dan", ProcessName.ANALYSIS, {"Analysis": []})
    compression_graphed = build("Dr. Jane", ProcessName.COMPRESSION, {"Compression": []})

    total = analysis_graphed + compression_graphed
    graph = total.dependency_graph
    self.assertIsNotNone(graph)
    self.assertEqual(len(graph), 2)
    self.assertIn("Analysis", graph)
    self.assertIn("Compression", graph)
    self.assertEqual(graph["Compression"], ["Analysis"])

    # Only the left operand carries a graph
    total = analysis_graphed + compression_bare
    self.assertIsNotNone(total.dependency_graph)
    self.assertIn("Analysis", total.dependency_graph)

    # Only the right operand carries a graph
    total = compression_bare + compression_graphed
    self.assertIsNotNone(total.dependency_graph)
    self.assertIn("Compression", total.dependency_graph)
314393

315394
if __name__ == "__main__":
316395
unittest.main()

tests/test_utils_merge.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
""" Tests for merge utilities """
22

33
import unittest
4+
from unittest.mock import Mock
45

56
from aind_data_schema.components.coordinates import CoordinateSystemLibrary
67
from aind_data_schema.utils.merge import (
78
merge_notes,
89
merge_optional_list,
910
merge_coordinate_systems,
1011
merge_str_tuple_lists,
12+
merge_process_graph,
1113
)
1214

1315

@@ -193,5 +195,37 @@ def test_both_different(self):
193195
merge_coordinate_systems(self.CSA, self.CSB)
194196

195197

198+
class MergeProcessGraphTests(unittest.TestCase):
    """Unit tests covering merge_process_graph"""

    def test_both_graphs_present(self):
        """Two populated graphs yield the union of their entries"""
        first = {"proc1": ["proc2"], "proc2": []}
        second = {"proc3": ["proc4"], "proc4": []}
        expected = {"proc1": ["proc2"], "proc2": [], "proc3": ["proc4"], "proc4": []}
        self.assertEqual(merge_process_graph(first, second, [], []), expected)

    def test_only_first_graph(self):
        """A missing second graph is filled in from the second process list"""
        first = {"proc1": ["proc2"], "proc2": []}
        # `name` must be set after construction: Mock(name=...) names the mock itself
        extra_process = Mock()
        extra_process.name = "proc3"
        expected = {"proc1": ["proc2"], "proc2": [], "proc3": []}
        self.assertEqual(merge_process_graph(first, None, [], [extra_process]), expected)

    def test_only_second_graph(self):
        """A missing first graph is filled in from the first process list"""
        second = {"proc3": ["proc4"], "proc4": []}
        # `name` must be set after construction: Mock(name=...) names the mock itself
        extra_process = Mock()
        extra_process.name = "proc1"
        expected = {"proc3": ["proc4"], "proc4": [], "proc1": []}
        self.assertEqual(merge_process_graph(None, second, [extra_process], []), expected)

    def test_both_graphs_none(self):
        """With no graph on either side the merge result is None"""
        self.assertIsNone(merge_process_graph(None, None, [], []))
229+
196230
if __name__ == "__main__":
197231
unittest.main()

0 commit comments

Comments
 (0)