Skip to content
Merged
14 changes: 7 additions & 7 deletions src/aind_data_schema/core/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,15 @@ def validate_subject_specimen_ids(self):
# Return if no specimen procedures
if self.specimen_procedures:
subject_id = self.subject_id
specimen_id_vars = [spec_proc.specimen_id for spec_proc in self.specimen_procedures]
specimen_ids = []
for spec_id_var in specimen_id_vars:
if isinstance(spec_id_var, str):
specimen_ids.append(spec_id_var)
flat_specimen_ids = []
for spec_proc in self.specimen_procedures:
sid = spec_proc.specimen_id
if isinstance(sid, list):
flat_specimen_ids.extend(sid)
else:
specimen_ids.extend(spec_id_var)
flat_specimen_ids.append(sid)

if any(not subject_specimen_id_compatibility(subject_id, spec_id) for spec_id in specimen_ids):
if any(not subject_specimen_id_compatibility(subject_id, spec_id) for spec_id in flat_specimen_ids):
raise ValueError("specimen_id must be an extension of the subject_id.")

return self
Expand Down
14 changes: 14 additions & 0 deletions tests/test_procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
PlanarSection,
PlanarSectioning,
Section,
Sectioning,
SectionOrientation,
SpecimenProcedure,
)
Expand Down Expand Up @@ -332,6 +333,19 @@ def test_validate_procedure_type(self):
)
)

self.assertIsNotNone(
SpecimenProcedure(
specimen_id="1000",
procedure_type="Sectioning",
start_date=date.fromisoformat("2020-10-10"),
end_date=date.fromisoformat("2020-10-11"),
experimenters=["Mam Moth"],
protocol_id=["10"],
notes=None,
procedure_details=[Sectioning(sections=[Section(output_specimen_id="1000_spinal")])],
)
)

def test_validate_procedure_type_multiple(self):
"""Test that error thrown when multiple types are passed to procedure_details"""

Expand Down
Loading