Skip to content

Commit 812aaf9

Browse files
committed
Make save_auxiliary_data recursiable
1 parent 0f255f7 commit 812aaf9

File tree

1 file changed

+46
-24
lines changed

1 file changed

+46
-24
lines changed

varipeps/peps/unitcell.py

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,8 +1088,10 @@ def save_to_file(
10881088

10891089
self.save_auxiliary_data(grp_aux, auxiliary_data)
10901090

1091-
@staticmethod
1092-
def save_auxiliary_data(grp: h5py.Group, auxiliary_data: Optional[Dict[str, Any]]):
1091+
@classmethod
1092+
def save_auxiliary_data(
1093+
cls, grp: h5py.Group, auxiliary_data: Optional[Dict[str, Any]]
1094+
):
10931095
"""
10941096
Save auxiliary data to HDF5 group.
10951097
@@ -1121,6 +1123,9 @@ def save_auxiliary_data(grp: h5py.Group, auxiliary_data: Optional[Dict[str, Any]
11211123
compression="gzip",
11221124
compression_opts=6,
11231125
)
1126+
elif isinstance(val, collections.abc.Mapping):
1127+
inner_grp = grp.create_group(key)
1128+
cls.save_auxiliary_data(inner_grp, val)
11241129
else:
11251130
grp.attrs[key] = val
11261131

@@ -1177,6 +1182,7 @@ def save_to_group(self, grp: h5py.Group, store_config: bool = True) -> None:
11771182
def load_from_file(
11781183
cls: Type[T_PEPS_Unit_Cell],
11791184
path: PathLike,
1185+
return_unitcell: bool = True,
11801186
return_config: bool = False,
11811187
return_auxiliary_data: bool = False,
11821188
) -> Union[
@@ -1191,6 +1197,8 @@ def load_from_file(
11911197
Args:
11921198
path (:obj:`os.PathLike`):
11931199
Path of the HDF5 file.
1200+
return_unitcell (:obj:`bool`):
1201+
Return the PEPS unit cell.
11941202
return_config (:obj:`bool`):
11951203
Return a config object initialized with the values from the HDF5
11961204
files. If no config is stored in the file, just the data is returned.
@@ -1201,22 +1209,25 @@ def load_from_file(
12011209
should be stored along the other data in the file.
12021210
"""
12031211
with h5py.File(path, "r") as f:
1204-
out = cls.load_from_group(f["unitcell"], return_config)
1205-
1206-
auxiliary_data = {}
1207-
if (auxiliary_data_grp := f.get("auxiliary_data")) is not None:
1208-
auxiliary_data = cls.load_auxiliary_data(auxiliary_data_grp)
1209-
elif (max_trunc_error_list := f.get("max_trunc_error_list")) is not None:
1210-
auxiliary_data["max_trunc_error_list"] = jnp.asarray(
1211-
max_trunc_error_list
1212-
)
1212+
out = cls.load_from_group(f["unitcell"], return_unitcell, return_config)
1213+
1214+
if return_auxiliary_data:
1215+
auxiliary_data = {}
1216+
if (auxiliary_data_grp := f.get("auxiliary_data")) is not None:
1217+
auxiliary_data = cls.load_auxiliary_data(auxiliary_data_grp)
1218+
elif (
1219+
max_trunc_error_list := f.get("max_trunc_error_list")
1220+
) is not None:
1221+
auxiliary_data["max_trunc_error_list"] = jnp.asarray(
1222+
max_trunc_error_list
1223+
)
12131224

1214-
if return_config and return_auxiliary_data:
1215-
return out[0], out[1], auxiliary_data
1216-
elif return_config:
1217-
return out[0], out[1]
1218-
elif return_auxiliary_data:
1219-
return out, auxiliary_data
1225+
if out is None:
1226+
out = auxiliary_data
1227+
elif isinstance(out, tuple):
1228+
out = out + (auxiliary_data,)
1229+
else:
1230+
out = (out, auxiliary_data)
12201231

12211232
return out
12221233

@@ -1241,7 +1252,10 @@ def load_auxiliary_data(grp: h5py.Group):
12411252

12421253
@classmethod
12431254
def load_from_group(
1244-
cls: Type[T_PEPS_Unit_Cell], grp: h5py.Group, return_config: bool = False
1255+
cls: Type[T_PEPS_Unit_Cell],
1256+
grp: h5py.Group,
1257+
return_unitcell: bool = True,
1258+
return_config: bool = False,
12451259
) -> Union[
12461260
T_PEPS_Unit_Cell, Tuple[T_PEPS_Unit_Cell, varipeps.config.VariPEPS_Config]
12471261
]:
@@ -1251,15 +1265,20 @@ def load_from_group(
12511265
Args:
12521266
grp (:obj:`h5py.Group`):
12531267
HDF5 group object to load the data from.
1268+
return_unitcell (:obj:`bool`):
1269+
Return the PEPS unit cell.
12541270
return_config (:obj:`bool`):
12551271
Return a config object initialized with the values from the HDF5
12561272
files. If no config is stored in the file, just the data is returned.
12571273
Missing config flags in the file uses the default values from the
12581274
config object.
12591275
"""
1260-
data = cls.Unit_Cell_Data.load_from_group(grp["data"])
1261-
real_ix = int(grp.attrs["real_ix"])
1262-
real_iy = int(grp.attrs["real_iy"])
1276+
if return_unitcell:
1277+
data = cls.Unit_Cell_Data.load_from_group(grp["data"])
1278+
real_ix = int(grp.attrs["real_ix"])
1279+
real_iy = int(grp.attrs["real_iy"])
1280+
elif not return_config:
1281+
return None
12631282

12641283
if return_config:
12651284
if grp.get("config") is None:
@@ -1293,9 +1312,12 @@ def load_from_group(
12931312
config_dict["slurm_restart_mode"]
12941313
)
12951314

1296-
return cls(
1297-
data=data, real_ix=real_ix, real_iy=real_iy
1298-
), varipeps.config.VariPEPS_Config(**config_dict)
1315+
if return_unitcell:
1316+
return cls(
1317+
data=data, real_ix=real_ix, real_iy=real_iy
1318+
), varipeps.config.VariPEPS_Config(**config_dict)
1319+
else:
1320+
return varipeps.config.VariPEPS_Config(**config_dict)
12991321

13001322
return cls(data=data, real_ix=real_ix, real_iy=real_iy)
13011323

0 commit comments

Comments
 (0)