-
Notifications
You must be signed in to change notification settings - Fork 21
Expand file tree
/
Copy pathprocessing.py
More file actions
278 lines (223 loc) · 12 KB
/
Copy pathprocessing.py
File metadata and controls
278 lines (223 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
"""schema for processing"""
import re
import warnings
from enum import Enum
from typing import Annotated, Dict, List, Literal, Optional
from aind_data_schema_models.process_names import ProcessName
from aind_data_schema_models.units import MemoryUnit, UnitlessUnit
from pydantic import Field, SkipValidation, ValidationInfo, field_validator, model_validator
from aind_data_schema.base import AwareDatetimeWithDefault, DataCoreModel, DataModel, GenericModel
from aind_data_schema.components.identifiers import Code
from aind_data_schema.components.wrappers import AssetPath
from aind_data_schema.utils.merge import merge_notes, merge_optional_list, merge_process_graph
from aind_data_schema.utils.validators import TimeValidation
class ProcessStage(str, Enum):
    """Stages of processing.

    A ``str``-mixin enum, so each member compares equal to its string value
    (e.g. ``ProcessStage.ANALYSIS == "Analysis"``).
    """

    PROCESSING = "Processing"
    ANALYSIS = "Analysis"
class ResourceTimestamped(DataModel):
    """Description of resource usage at a moment in time.

    One (timestamp, usage) sample. Lists of these form the ``cpu_usage``,
    ``gpu_usage`` and ``ram_usage`` series on ``ResourceUsage``.
    """

    timestamp: AwareDatetimeWithDefault = Field(..., title="Timestamp")
    # No unit is stored per sample; presumably it is given by
    # ResourceUsage.usage_unit on the owning object -- confirm.
    usage: float = Field(..., title="Usage")
class ResourceUsage(DataModel):
    """Description of resources used by a process.

    Only ``os`` and ``architecture`` are required. Hardware details, memory
    amounts (each paired with a unit field) and timestamped usage series are
    optional.
    """

    os: str = Field(..., title="Operating system")
    architecture: str = Field(..., title="Architecture")
    cpu: Optional[str] = Field(default=None, title="CPU name")
    cpu_cores: Optional[int] = Field(default=None, title="CPU cores")
    gpu: Optional[str] = Field(default=None, title="GPU name")
    # Memory amounts are plain numbers; their units live in the *_unit fields that follow each one
    system_memory: Optional[float] = Field(default=None, title="System memory")
    system_memory_unit: Optional[MemoryUnit] = Field(default=None, title="System memory unit")
    ram: Optional[float] = Field(default=None, title="System RAM")
    ram_unit: Optional[MemoryUnit] = Field(default=None, title="Ram unit")
    # Usage time series, each a list of ResourceTimestamped samples
    cpu_usage: Optional[List[ResourceTimestamped]] = Field(default=None, title="CPU usage")
    gpu_usage: Optional[List[ResourceTimestamped]] = Field(default=None, title="GPU usage")
    ram_usage: Optional[List[ResourceTimestamped]] = Field(default=None, title="RAM usage")
    # NOTE(review): presumably the unit of the *_usage samples above. Annotated as str,
    # but the default is the UnitlessUnit.PERCENT enum member -- confirm intent.
    usage_unit: str = Field(default=UnitlessUnit.PERCENT, title="Usage unit")
class DataProcess(DataModel):
    """Description of a single processing step.

    Validation performed on construction:

    * ``output_path`` accepts a single path or a list of paths and is
      normalized to a list of ``AssetPath``.
    * ``notes`` must be non-empty when ``process_type`` is ``ProcessName.OTHER``.
    * An empty ``name`` is replaced by ``process_type``.
    """

    process_type: ProcessName = Field(..., title="Process type")
    name: str = Field(
        default="",
        title="Name",
        # Adjacent literals with no comma between them, so Python joins them into a
        # single string. A comma here would turn the description into a tuple.
        description=(
            "Unique name of the processing step."
            " If not provided, the type will be used as the name."
        ),
    )
    stage: ProcessStage = Field(..., title="Processing stage")
    code: Code = Field(..., title="Code", description="Code used for processing")
    experimenters: List[str] = Field(..., title="Experimenters", description="People responsible for processing")
    pipeline_name: Optional[str] = Field(
        default=None, title="Pipeline name", description="Pipeline names must exist in Processing.pipelines"
    )
    start_date_time: Annotated[AwareDatetimeWithDefault, TimeValidation.AFTER] = Field(..., title="Start date time")
    end_date_time: Optional[Annotated[AwareDatetimeWithDefault, TimeValidation.AFTER]] = Field(
        default=None, title="End date time"
    )
    output_path: Optional[List[AssetPath]] = Field(
        default=None, title="Output path", description="Path to processing outputs, if stored."
    )
    output_parameters: Optional[GenericModel] = Field(default=None, description="Output parameters", title="Outputs")
    # validate_default=True so validate_other also runs when notes is omitted
    notes: Optional[str] = Field(default=None, title="Notes", validate_default=True)
    resources: Optional[ResourceUsage] = Field(default=None, title="Process resource usage")

    @field_validator("output_path", mode="before")
    @classmethod
    def validate_output_path(cls, value) -> Optional[List[AssetPath]]:
        """Normalize ``output_path`` to a list of ``AssetPath``.

        Args:
            value: ``None``, a single path, or a list of paths.

        Returns:
            ``None`` unchanged. Otherwise, a list with every element wrapped in ``AssetPath``.
        """
        if value is None:
            return value
        # Accept a single path for convenience by wrapping it in a list
        if not isinstance(value, list):
            value = [value]
        return [AssetPath(path) for path in value]

    @field_validator("notes", mode="after")
    @classmethod
    def validate_other(cls, value: Optional[str], info: ValidationInfo) -> Optional[str]:
        """Require ``notes`` to describe the processing when ``process_type`` is Other.

        Args:
            value: The validated ``notes`` value.
            info: Pydantic validation context. ``info.data`` holds fields validated so far.

        Returns:
            ``value`` unchanged.

        Raises:
            ValueError: If ``process_type`` is ``ProcessName.OTHER`` and ``notes`` is empty or None.
        """
        # process_type is declared before notes, so it is present in info.data if it validated
        if info.data.get("process_type") == ProcessName.OTHER and not value:
            raise ValueError(
                "Notes cannot be empty if 'process_type' is Other. Describe the type of processing in the notes field."
            )
        return value

    @model_validator(mode="after")
    def fill_default_name(self) -> "DataProcess":
        """Fill in ``name`` from ``process_type`` when no name was provided."""
        if not self.name:
            self.name = self.process_type
        return self
class Processing(DataCoreModel):
    """Description of all processes run on data.

    Validation on construction (``mode="after"`` model validators):

    * ``data_processes`` are sorted by ``start_date_time`` when out of order,
      and a note recording the reordering is appended to ``notes``.
    * When ``dependency_graph`` is given, process names must be unique and the
      graph's keys must match the set of process names exactly.
    * Every ``DataProcess.pipeline_name`` must match the ``name`` of an entry
      in ``pipelines``.
    """

    _DESCRIBED_BY_URL: str = DataCoreModel._DESCRIBED_BY_BASE_URL.default + "aind_data_schema/core/processing.py"
    describedBy: str = Field(default=_DESCRIBED_BY_URL, json_schema_extra={"const": _DESCRIBED_BY_URL})
    schema_version: SkipValidation[Literal["2.2.5"]] = Field(default="2.2.5")

    data_processes: List[DataProcess] = Field(..., title="Data processing")
    pipelines: Optional[List[Code]] = Field(
        default=None,
        title="Pipelines",
        description=(
            "For processing done with pipelines, list the repositories here. Pipelines must use the name field "
            "and be referenced in the pipeline_name field of a DataProcess."
        ),
    )
    notes: Optional[str] = Field(default=None, title="Notes")
    dependency_graph: Optional[Dict[str, List[str]]] = Field(
        default=None,
        title="Dependency graph",
        description=(
            "Directed graph of processing step dependencies. Each key is a process name, and the value is a list of "
            "process names that are inputs to that process."
        ),
    )

    @property
    def process_names(self) -> List[str]:
        """Return the names of data processes, in ``data_processes`` order."""
        return [process.name for process in self.data_processes]

    def rename_process(self, old_name: str, new_name: str) -> None:
        """Rename a process in the processing object, including all references.

        Updates the matching ``DataProcess.name``. When a dependency graph is present, it also
        moves the graph key and replaces every occurrence of ``old_name`` in the graph's input lists.

        Args:
            old_name: Current name of the process.
            new_name: Name to give it.

        Raises:
            ValueError: If no process is named ``old_name``, or if ``new_name`` is already
                used by another process. The latter would create duplicate names and
                overwrite that process's graph entry.
        """
        target = next((process for process in self.data_processes if process.name == old_name), None)
        if target is None:
            raise ValueError(f"Process '{old_name}' not found in data_processes.")
        if new_name != old_name and new_name in self.process_names:
            raise ValueError(f"Process '{new_name}' already exists in data_processes.")

        target.name = new_name

        if self.dependency_graph:
            # Guarded so a graph that is missing the key cannot leave a half-renamed object via KeyError
            if old_name in self.dependency_graph:
                self.dependency_graph[new_name] = self.dependency_graph.pop(old_name)
            # Replace every reference in-place so existing list objects stay shared
            for inputs in self.dependency_graph.values():
                inputs[:] = [new_name if item == old_name else item for item in inputs]

    @model_validator(mode="after")
    def order_processes(self) -> "Processing":
        """Sort ``data_processes`` by ``start_date_time`` if out of order, and record it in ``notes``."""
        if not hasattr(self, "data_processes") or not self.data_processes:
            return self

        start_times = [process.start_date_time for process in self.data_processes]
        if any(later < earlier for earlier, later in zip(start_times, start_times[1:])):
            self.data_processes.sort(key=lambda x: x.start_date_time)
            self.notes = (
                "Processes were reordered by start_date_time"
                if not self.notes
                else f"{self.notes}; Processes were reordered by start_date_time"
            )
        return self

    @classmethod
    def create_with_sequential_process_graph(cls, data_processes: List[DataProcess], **kwargs) -> "Processing":
        """Build a Processing whose dependency graph chains the processes in list order.

        The first process gets no inputs. Each later process gets the previous process as its
        single input.

        Args:
            data_processes: Processes to include, in the order they should be chained.
            **kwargs: Forwarded to the ``Processing`` constructor.
        """
        dependency_graph = {}
        for i, process in enumerate(data_processes):
            dependency_graph[process.name] = [] if i == 0 else [data_processes[i - 1].name]
        return cls(dependency_graph=dependency_graph, data_processes=data_processes, **kwargs)

    @model_validator(mode="after")
    def validate_process_graph(self):
        """Check that the same processes are represented in data_processes and dependency_graph.

        Raises:
            ValueError: If process names are not unique, or if the graph keys and the
                process names differ in either direction.
        """
        if not hasattr(self, "data_processes"):  # bypass for testing
            return self

        # If the dependency_graph is None, then no need to validate
        if self.dependency_graph is None:
            return self

        processes = set(self.process_names)
        if len(processes) != len(self.data_processes):
            raise ValueError("data_processes must have unique names.")

        # NOTE(review): only the graph keys are checked; names inside the input lists
        # (graph values) are not verified against data_processes.
        graph_processes = set(self.dependency_graph.keys())
        missing_processes = processes - graph_processes
        if missing_processes:
            raise ValueError(
                f"dependency_graph must include all processes in data_processes. Missing processes: {missing_processes}"
            )

        missing_processes = graph_processes - processes
        if missing_processes:
            raise ValueError(
                f"data_processes must include all processes in dependency_graph. Missing processes: {missing_processes}"
            )

        return self

    @model_validator(mode="after")
    def validate_pipeline_names(self):
        """Ensure that all pipeline names in the processes are in the pipelines list.

        Raises:
            ValueError: If a process references a ``pipeline_name`` not found among ``pipelines``.
        """
        if not hasattr(self, "data_processes"):  # bypass for testing
            return self

        pipeline_names = {pipeline.name for pipeline in self.pipelines} if self.pipelines else set()
        for process in self.data_processes:
            if process.pipeline_name and process.pipeline_name not in pipeline_names:
                raise ValueError(f"Pipeline name '{process.pipeline_name}' not found in pipelines list.")

        return self

    def __add__(self, other: "Processing") -> "Processing":
        """Combine two Processing objects into a new one.

        Neither operand is modified. Processes in ``other`` whose names clash with ``self``
        are renamed with a numeric suffix (``name_1``, ``name_2``, ...) and a warning is
        issued. The dependency graphs are merged. When a merged graph exists, the first
        process of ``other`` is linked to take the last process of ``self`` as its input.

        Raises:
            ValueError: If the two objects have different ``schema_version`` values.
        """
        if self.schema_version != other.schema_version:
            raise ValueError("Cannot add Processing objects with different schema versions.")

        # Work on deep copies: renaming below mutates processes and graphs in place
        left = self.model_copy(deep=True)
        right = other.model_copy(deep=True)

        repeated_processes = set(left.process_names) & set(right.process_names)
        if repeated_processes:
            warnings.warn(
                f"Processing objects have repeated processes: {repeated_processes}. Renaming duplicates.",
                stacklevel=2,
            )
        for name in sorted(repeated_processes):
            # Strip an existing numeric suffix so "name_1" is renamed to "name_2", not "name_1_1"
            base_name = re.sub(r"_\d+$", "", name)
            # Recomputed each pass so earlier renames are taken into account
            existing_names = set(left.process_names + right.process_names)
            new_name = name
            suffix = 1
            while new_name in existing_names:
                new_name = f"{base_name}_{suffix}"
                suffix += 1
            right.rename_process(name, new_name)

        merged_graph = merge_process_graph(
            left.dependency_graph, right.dependency_graph, left.data_processes, right.data_processes
        )

        # Link left's output to right's input. This only makes sense if left has a single
        # output process and right has a single input process.
        if merged_graph and left.data_processes and right.data_processes:
            merged_graph[right.data_processes[0].name] = [left.data_processes[-1].name]

        return Processing(
            pipelines=merge_optional_list(left.pipelines, right.pipelines),
            data_processes=left.data_processes + right.data_processes,
            dependency_graph=merged_graph,
            notes=merge_notes(left.notes, right.notes),
        )