Skip to content

Commit 61ff3dc

Browse files
authored
feat: compute group:value pairs for tag status (#1673)
* feat: compute group:value pairs for tag status * chore: lint * tests: add coverage for properties
1 parent 6ebb311 commit 61ff3dc

2 files changed

Lines changed: 77 additions & 27 deletions

File tree

src/aind_data_schema/core/quality_control.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ class QualityControl(DataCoreModel):
146146

147147
@property
148148
def tags(self) -> List[str]:
149-
"""Get all unique tags from all metrics
149+
"""Get all unique tag values from all metrics
150150
151151
Returns
152152
-------
@@ -158,6 +158,21 @@ def tags(self) -> List[str]:
158158
all_tags.extend(metric.tags.values())
159159
return list(set(all_tags))
160160

161+
@property
162+
def tag_pairs(self) -> List[str]:
163+
"""Get all unique tag key:value pairs from all metrics
164+
165+
Returns
166+
-------
167+
List[str]
168+
List of all unique tag key:value pairs across all metrics in 'key:value' format
169+
"""
170+
all_tag_pairs = []
171+
for metric in self.metrics:
172+
for key, value in metric.tags.items():
173+
all_tag_pairs.append(f"{key}:{value}")
174+
return list(set(all_tag_pairs))
175+
161176
@property
162177
def modalities(self) -> List[Modality.ONE_OF]:
163178
"""Get all unique modalities from all metrics
@@ -192,9 +207,9 @@ def compute_status(self):
192207
if self.metrics:
193208
computed_status = {}
194209

195-
# Compute tag statuses
196-
for tag in self.tags:
197-
computed_status[tag] = self.evaluate_status(tag=tag)
210+
# Compute tag statuses (using key:value format)
211+
for tag_pair in self.tag_pairs:
212+
computed_status[tag_pair] = self.evaluate_status(tag=tag_pair)
198213

199214
# Compute modality statuses
200215
for modality in self.modalities:
@@ -335,23 +350,36 @@ def _get_filtered_statuses(
335350
tag_filter: Optional[List[str]] = None,
336351
allow_tag_failures: List[str] = [],
337352
):
338-
"""Get the status of metrics filtered by modality, stage, tag, and date."""
353+
"""Get the status of metrics filtered by modality, stage, tag, and date.
354+
355+
tag_filter can contain either 'key:value' pairs or just tag values for backward compatibility.
356+
allow_tag_failures can contain either 'key:value' pairs or just tag values.
357+
"""
339358
filtered_statuses = []
340359
for metric in metrics:
341360
# Apply filters
342361
if modality_filter and metric.modality not in modality_filter:
343362
continue
344363
if stage_filter and metric.stage not in stage_filter:
345364
continue
346-
if tag_filter and not (metric.tags and any(t in metric.tags.values() for t in tag_filter)):
347-
continue
365+
if tag_filter:
366+
# Check if any of the filter tags match this metric's tags
367+
# Support both 'key:value' format and just values for backward compatibility
368+
metric_tag_pairs = [f"{k}:{v}" for k, v in metric.tags.items()]
369+
metric_tag_values = list(metric.tags.values())
370+
if not any(t in metric_tag_pairs or t in metric_tag_values for t in tag_filter):
371+
continue
348372

349373
# Get status at the specified date using the helper function
350374
status = _get_status_by_date(metric, date)
351-
# Check if any of our tag values are in the allow_tag_failures list
375+
# Check if any of our tag key:value pairs or values are in the allow_tag_failures list
352376
if status == Status.FAIL and metric.tags:
353-
metric_tag_values = set(metric.tags.values())
354-
if any(tag_value in allow_tag_failures for tag_value in metric_tag_values):
377+
metric_tag_pairs = [f"{k}:{v}" for k, v in metric.tags.items()]
378+
metric_tag_values = list(metric.tags.values())
379+
if any(
380+
tag_pair in allow_tag_failures or tag_value in allow_tag_failures
381+
for tag_pair, tag_value in zip(metric_tag_pairs, metric_tag_values)
382+
):
355383
status = Status.PASS
356384
filtered_statuses.append(status)
357385

tests/test_quality_control.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,28 @@ def test_tags_list_to_dict_conversion(self):
4747
self.assertIsInstance(metric.tags, dict)
4848
self.assertEqual(metric.tags, {"tag_1": "tag1", "tag_2": "tag2", "tag_3": "tag3"})
4949

50+
def test_tags_property(self):
51+
"""test that QualityControl.tags returns all unique tag values"""
52+
tags = quality_control.tags
53+
self.assertIsInstance(tags, list)
54+
self.assertIn("Probe A", tags)
55+
self.assertIn("Probe B", tags)
56+
self.assertIn("Probe C", tags)
57+
self.assertIn("Video 1", tags)
58+
self.assertIn("Video 2", tags)
59+
self.assertEqual(len(tags), 5)
60+
61+
def test_tag_pairs_property(self):
62+
"""test that QualityControl.tag_pairs returns all unique key:value pairs"""
63+
tag_pairs = quality_control.tag_pairs
64+
self.assertIsInstance(tag_pairs, list)
65+
self.assertIn("probe:Probe A", tag_pairs)
66+
self.assertIn("probe:Probe B", tag_pairs)
67+
self.assertIn("probe:Probe C", tag_pairs)
68+
self.assertIn("video:Video 1", tag_pairs)
69+
self.assertIn("video:Video 2", tag_pairs)
70+
self.assertEqual(len(tag_pairs), 5)
71+
5072
def test_overall_status(self):
5173
"""test that overall status goes to pass/pending/fail correctly"""
5274

@@ -186,7 +208,7 @@ def test_evaluation_status(self):
186208
)
187209
)
188210

189-
self.assertEqual(qc.evaluate_status(tag="Drift map"), Status.FAIL)
211+
self.assertEqual(qc.evaluate_status(tag="group:Drift map"), Status.FAIL)
190212

191213
def test_allowed_failed_metrics(self):
192214
"""Test that if you set the flag to allow failures that tags pass"""
@@ -224,17 +246,17 @@ def test_allowed_failed_metrics(self):
224246
default_grouping=["group"],
225247
)
226248

227-
self.assertEqual(qc.evaluate_status(tag="Drift map"), Status.PENDING)
249+
self.assertEqual(qc.evaluate_status(tag="group:Drift map"), Status.PENDING)
228250

229251
# Replace the pending evaluation with a fail, evaluation should not evaluate to pass
230252
qc.metrics[1].status_history[0].status = Status.FAIL
231253

232-
self.assertEqual(qc.evaluate_status(tag="Drift map"), Status.FAIL)
254+
self.assertEqual(qc.evaluate_status(tag="group:Drift map"), Status.FAIL)
233255

234256
# Now add the tag to allow_tag_failures
235-
qc.allow_tag_failures = ["Drift map"]
257+
qc.allow_tag_failures = ["group:Drift map"]
236258

237-
self.assertEqual(qc.evaluate_status(tag="Drift map"), Status.PASS)
259+
self.assertEqual(qc.evaluate_status(tag="group:Drift map"), Status.PASS)
238260

239261
def test_metric_history_order(self):
240262
"""Test that the order of the metric status history list is preserved when dumping"""
@@ -441,10 +463,10 @@ def test_status_filters(self):
441463
"behavior": Status.FAIL,
442464
"behavior-videos": Status.PENDING,
443465
"ecephys": Status.PASS,
444-
# Tags
445-
"test_group": Status.PASS,
446-
"test_group2": Status.FAIL,
447-
"tag1": Status.PENDING,
466+
# Tags (now in key:value format)
467+
"group:test_group": Status.PASS,
468+
"group:test_group2": Status.FAIL,
469+
"type:tag1": Status.PENDING,
448470
},
449471
)
450472

@@ -454,7 +476,7 @@ def test_status_filters(self):
454476
self.assertEqual(q.evaluate_status(modality=[Modality.ECEPHYS, Modality.BEHAVIOR]), Status.FAIL)
455477
self.assertEqual(q.evaluate_status(stage=Stage.RAW), Status.FAIL)
456478
self.assertEqual(q.evaluate_status(stage=Stage.PROCESSING), Status.PASS)
457-
self.assertEqual(q.evaluate_status(tag="tag1"), Status.PENDING)
479+
self.assertEqual(q.evaluate_status(tag="type:tag1"), Status.PENDING)
458480

459481
def test_status_date(self):
460482
"""QualityControl.status(date=) should return the correct status for the given date"""
@@ -635,7 +657,7 @@ def test_get_filtered_statuses_helper(self):
635657
shared_tag_statuses = _get_filtered_statuses(
636658
metrics=all_metrics,
637659
date=test_date,
638-
tag_filter=["shared_tag"],
660+
tag_filter=["group:shared_tag"],
639661
)
640662
self.assertEqual(len(shared_tag_statuses), 2) # Our BEHAVIOR and OPHYS test metrics
641663
self.assertIn(Status.PASS, shared_tag_statuses)
@@ -655,7 +677,7 @@ def test_get_filtered_statuses_helper(self):
655677
time_test_statuses = _get_filtered_statuses(
656678
metrics=all_metrics,
657679
date=earlier_date,
658-
tag_filter=["time_test"],
680+
tag_filter=["test:time_test"],
659681
)
660682
self.assertEqual(len(time_test_statuses), 1)
661683
self.assertEqual(time_test_statuses[0], Status.FAIL) # Should get the earlier FAIL status
@@ -664,8 +686,8 @@ def test_get_filtered_statuses_helper(self):
664686
ophys_fail_statuses = _get_filtered_statuses(
665687
metrics=all_metrics,
666688
date=test_date,
667-
tag_filter=["ophys_tag"],
668-
allow_tag_failures=["ophys_tag"],
689+
tag_filter=["type:ophys_tag"],
690+
allow_tag_failures=["type:ophys_tag"],
669691
)
670692
self.assertEqual(len(ophys_fail_statuses), 1)
671693
self.assertEqual(ophys_fail_statuses[0], Status.PASS) # FAIL converted to PASS
@@ -756,16 +778,16 @@ def test_helper_functions_integration(self):
756778
# Test status at different times
757779
early_date = datetime.fromisoformat("2020-02-01T00:00:00+00:00")
758780
# At early date: metric 1 is FAIL, metric 2 is PASS -> overall FAIL
759-
early_status = qc.evaluate_status(date=early_date, tag="time_sensitive")
781+
early_status = qc.evaluate_status(date=early_date, tag="group:time_sensitive")
760782
self.assertEqual(early_status, Status.FAIL)
761783

762784
# At test date: metric 1 is PASS, metric 2 is PASS -> overall PASS
763-
test_status = qc.evaluate_status(date=test_date, tag="time_sensitive")
785+
test_status = qc.evaluate_status(date=test_date, tag="group:time_sensitive")
764786
self.assertEqual(test_status, Status.PASS)
765787

766788
# At late date: metric 1 is PASS, metric 2 is FAIL -> overall FAIL
767789
late_date = datetime.fromisoformat("2020-08-01T00:00:00+00:00")
768-
late_status = qc.evaluate_status(date=late_date, tag="time_sensitive")
790+
late_status = qc.evaluate_status(date=late_date, tag="group:time_sensitive")
769791
self.assertEqual(late_status, Status.FAIL)
770792

771793
def test_backwards_compatibility_default_grouping(self):

0 commit comments

Comments
 (0)