Skip to content

Commit cd56ce8

Browse files
feat/merge calcjobs (#36)
* merge calcjobs * fixed entry-points * format
1 parent 04e7bac commit cd56ce8

3 files changed

Lines changed: 71 additions & 88 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ dependencies = [
2626
# [project.entry-points."aiida.data"]
2727
# "fans" = "aiida_fans.data:FANSParameters"
2828
[project.entry-points."aiida.calculations"]
29-
"fans.stashed" = "aiida_fans.calculations:FansStashedCalculation"
30-
"fans.fragmented" = "aiida_fans.calculations:FansFragmentedCalculation"
29+
"fans" = "aiida_fans.calculations:FansCalculation"
3130
[project.entry-points."aiida.parsers"]
3231
"fans" = "aiida_fans.parsers:FansParser"
3332
# [project.entry-points."aiida.cmdline.data"]

src/aiida_fans/calculations.py

Lines changed: 37 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from aiida_fans.helpers import make_input_dict
1515

1616

17-
class FansCalcBase(CalcJob):
18-
"""Base class of all calculations using FANS."""
17+
class FansCalculation(CalcJob):
18+
"""Calculations using FANS."""
1919

2020
@classmethod
2121
def define(cls, spec: CalcJobProcessSpec) -> None:
@@ -39,6 +39,7 @@ def define(cls, spec: CalcJobProcessSpec) -> None:
3939
# Custom Metadata
4040
spec.input("metadata.options.results_prefix", valid_type=str, default="")
4141
spec.input("metadata.options.results", valid_type=list, default=[])
42+
spec.input("metadata.options.stashed_microstructure", valid_type=bool, default=True)
4243

4344
# Input Ports
4445
## Microstructure Definition
@@ -69,6 +70,40 @@ def define(cls, spec: CalcJobProcessSpec) -> None:
6970

7071
def prepare_for_submission(self, folder: Folder) -> CalcInfo:
7172
"""Prepare the calculation for submission."""
73+
# Stashed Strategy:
74+
if self.options.stashed_microstructure:
75+
ms_filepath: Path = Path(self.inputs.code.computer.get_workdir()) / \
76+
"stash/microstructures" / \
77+
self.inputs.microstructure.file.filename
78+
# if microstructure does not exist in stash, make it
79+
if not ms_filepath.is_file():
80+
ms_filepath.parent.mkdir(parents=True, exist_ok=True)
81+
with self.inputs.microstructure.file.open(mode='rb') as source:
82+
with ms_filepath.open(mode='wb') as target:
83+
copyfileobj(source, target)
84+
85+
# input.json as dict
86+
input_dict = make_input_dict(self)
87+
input_dict["microstructure"]["filepath"] = str(ms_filepath)
88+
# write input.json to working directory
89+
with folder.open(self.options.input_filename, "w", "utf8") as json:
90+
dump(input_dict, json, indent=4)
91+
# Fragmented Strategy:
92+
else:
93+
datasetname : str = self.inputs.microstructure.datasetname.value
94+
with folder.open("microstructure.h5","bw") as f_dest:
95+
with h5File(f_dest,"w") as h5_dest:
96+
with self.inputs.microstructure.file.open(mode="rb") as f_src:
97+
with h5File(f_src,'r') as h5_src:
98+
h5_src.copy(datasetname, h5_dest, name=datasetname)
99+
100+
# input.json as dict
101+
input_dict = make_input_dict(self)
102+
input_dict["microstructure"]["filepath"] = "microstructure.h5"
103+
# write input.json to working directory
104+
with folder.open(self.options.input_filename, "w", "utf8") as json:
105+
dump(input_dict, json, indent=4)
106+
72107
# Specifying the code info:
73108
codeinfo = CodeInfo()
74109
codeinfo.code_uuid = self.inputs.code.uuid
@@ -87,60 +122,3 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
87122
]
88123

89124
return calcinfo
90-
91-
92-
class FansStashedCalculation(FansCalcBase):
93-
"""Calculations using FANS and the "Stashed" microstructure distribution strategy."""
94-
95-
@classmethod
96-
def define(cls, spec: CalcJobProcessSpec) -> None:
97-
"""Define inputs, outputs, and exit codes of the calculation."""
98-
return super().define(spec)
99-
100-
def prepare_for_submission(self, folder: Folder) -> CalcInfo:
101-
"""Prepare the calculation for submission."""
102-
ms_filepath: Path = Path(self.inputs.code.computer.get_workdir()) / \
103-
"stash/microstructures" / \
104-
self.inputs.microstructure.file.filename
105-
# if microstructure does not exist in stash, make it
106-
if not ms_filepath.is_file():
107-
ms_filepath.parent.mkdir(parents=True, exist_ok=True)
108-
with self.inputs.microstructure.file.open(mode='rb') as source:
109-
with ms_filepath.open(mode='wb') as target:
110-
copyfileobj(source, target)
111-
112-
# input.json as dict
113-
input_dict = make_input_dict(self)
114-
input_dict["microstructure"]["filepath"] = str(ms_filepath)
115-
# write input.json to working directory
116-
with folder.open(self.options.input_filename, "w", "utf8") as json:
117-
dump(input_dict, json, indent=4)
118-
119-
return super().prepare_for_submission(folder)
120-
121-
class FansFragmentedCalculation(FansCalcBase):
122-
"""Calculations using FANS and the "Fragmented" microstructure distribution strategy."""
123-
124-
@classmethod
125-
def define(cls, spec: CalcJobProcessSpec) -> None:
126-
"""Define inputs, outputs, and exit codes of the calculation."""
127-
return super().define(spec)
128-
129-
def prepare_for_submission(self, folder: Folder) -> CalcInfo:
130-
"""Prepare the calculation for submission."""
131-
# Write Microstructure Subset to Folder
132-
datasetname : str = self.inputs.microstructure.datasetname.value
133-
with folder.open("microstructure.h5","bw") as f_dest:
134-
with h5File(f_dest,"w") as h5_dest:
135-
with self.inputs.microstructure.file.open(mode="rb") as f_src:
136-
with h5File(f_src,'r') as h5_src:
137-
h5_src.copy(datasetname, h5_dest, name=datasetname)
138-
139-
# input.json as dict
140-
input_dict = make_input_dict(self)
141-
input_dict["microstructure"]["filepath"] = "microstructure.h5"
142-
# write input.json to working directory
143-
with folder.open(self.options.input_filename, "w", "utf8") as json:
144-
dump(input_dict, json, indent=4)
145-
146-
return super().prepare_for_submission(folder)

src/aiida_fans/utils.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,22 @@ def aiida_type(value: Any) -> type[Data]:
2424
"""
2525
match value:
2626
case str():
27-
return DataFactory("core.str") # Str
27+
return DataFactory("core.str") # Str
2828
case int():
29-
return DataFactory("core.int") # Int
29+
return DataFactory("core.int") # Int
3030
case float():
31-
return DataFactory("core.float") # Float
31+
return DataFactory("core.float") # Float
3232
case list():
33-
return DataFactory("core.list") # List
33+
return DataFactory("core.list") # List
3434
case dict():
3535
if all(map(lambda t: isinstance(t, ndarray), value.values())):
36-
return DataFactory("core.array") # ArrayData
36+
return DataFactory("core.array") # ArrayData
3737
else:
38-
return DataFactory("core.dict") # Dict
38+
return DataFactory("core.dict") # Dict
3939
case _:
4040
raise NotImplementedError(f"Received an input of value: {value} with type: {type(value)}")
4141

42+
4243
def fetch(label: str, value: Any) -> list[Node]:
4344
"""Return a list of nodes matching the label and value provided.
4445
@@ -50,26 +51,31 @@ def fetch(label: str, value: Any) -> list[Node]:
5051
list[Node]: the list of nodes matching the give criteria
5152
"""
5253
datatype = aiida_type(value)
53-
nodes = QueryBuilder(
54-
).append(cls=datatype, tag="n"
55-
).add_filter("n", {"label": label}
56-
).add_filter("n", {"attributes": {"==": datatype(value).base.attributes.all}}
57-
).all(flat=True)
54+
nodes = (
55+
QueryBuilder()
56+
.append(cls=datatype, tag="n")
57+
.add_filter("n", {"label": label})
58+
.add_filter("n", {"attributes": {"==": datatype(value).base.attributes.all}})
59+
.all(flat=True)
60+
)
5861

5962
if datatype != DataFactory("core.array"):
60-
return nodes # type: ignore
63+
return nodes # type: ignore
6164
else:
6265
array_nodes = []
6366
for array_node in nodes:
6467
array_value = {
65-
k: v for k, v in [
66-
(name, array_node.get_array(name)) for name in array_node.get_arraynames() # type: ignore
68+
k: v
69+
for k, v in [
70+
(name, array_node.get_array(name))
71+
for name in array_node.get_arraynames() # type: ignore
6772
]
6873
}
6974
if arraydata_equal(value, array_value):
7075
array_nodes.append(array_node)
7176
return array_nodes
7277

78+
7379
def generate(label: str, value: Any) -> Node:
7480
"""Return a single node with the label and value provided.
7581
@@ -93,6 +99,7 @@ def generate(label: str, value: Any) -> Node:
9399
else:
94100
raise RuntimeError
95101

102+
96103
def convert(ins: dict[str, Any], path: list[str] = []):
97104
"""Takes a dictionary of inputs and converts the values to their respective Nodes.
98105
@@ -108,7 +115,8 @@ def convert(ins: dict[str, Any], path: list[str] = []):
108115
else:
109116
ins[k] = generate(".".join([*path, k]), v)
110117

111-
def compile_query(ins: dict[str,Any], qb: QueryBuilder) -> None:
118+
119+
def compile_query(ins: dict[str, Any], qb: QueryBuilder) -> None:
112120
"""Interate over the converted input dictionary and append to the QueryBuilder for each node.
113121
114122
Args:
@@ -121,18 +129,14 @@ def compile_query(ins: dict[str,Any], qb: QueryBuilder) -> None:
121129
if k in ["microstructure", "error_parameters"] and isinstance(v, dict):
122130
compile_query(v, qb)
123131
else:
124-
qb.append(
125-
cls=type(v),
126-
with_outgoing="calc",
127-
filters={"pk": v.pk}
128-
)
132+
qb.append(cls=type(v), with_outgoing="calc", filters={"pk": v.pk})
129133

130134

131135
def execute_fans(
132-
mode: Literal["Submit", "Run"],
133-
inputs: dict[str, Any],
134-
strategy: Literal["Fragmented", "Stashed"] = "Fragmented",
135-
):
136+
mode: Literal["Submit", "Run"],
137+
inputs: dict[str, Any],
138+
strategy: Literal["Fragmented", "Stashed"] = "Fragmented",
139+
):
136140
"""This utility function simplifies the process of executing aiida-fans jobs.
137141
138142
The only nodes you must provide are the `code` and `microstructure` inputs.
@@ -191,17 +195,18 @@ def execute_fans(
191195
compile_query(inputs, qb)
192196
results = qb.all(flat=True)
193197
if (count := len(results)) != 0:
194-
print(f"It seems this calculation has already been performed {count} time{"s" if count > 1 else ""}. {results}")
198+
print(f"It seems this calculation has already been performed {count} time{'s' if count > 1 else ''}. {results}")
195199
confirmation = input("Are you sure you want to rerun it? [y/N] ").strip().lower() in ["y", "yes"]
196200
else:
197201
confirmation = True
198202

199203
if confirmation:
200204
match mode:
201205
case "Run":
202-
run(calcjob, inputs) # type: ignore
206+
run(calcjob, inputs) # type: ignore
203207
case "Submit":
204-
submit(calcjob, inputs) # type: ignore
208+
submit(calcjob, inputs) # type: ignore
209+
205210

206211
def submit_fans(
207212
inputs: dict[str, Any],
@@ -210,6 +215,7 @@ def submit_fans(
210215
"""See `execute_fans` for implementation and usage details."""
211216
execute_fans("Submit", inputs, strategy)
212217

218+
213219
def run_fans(
214220
inputs: dict[str, Any],
215221
strategy: Literal["Fragmented", "Stashed"] = "Fragmented",

0 commit comments

Comments
 (0)