Skip to content

Commit 6d1e7a3

Browse files
committed
fix pred_visualization
1 parent 9795acc commit 6d1e7a3

2 files changed

Lines changed: 66 additions & 25 deletions

File tree

scripts/data_visualization/pred_visualization.ipynb

Lines changed: 57 additions & 25 deletions
Large diffs are not rendered by default.

src/basicts/runners/basicts_runner.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,7 @@ def _save_results(self, batch_idx: int, batch_data: Dict[str, torch.Tensor]) ->
10791079
inputs_path = os.path.join(save_dir, "inputs.npy")
10801080
pred_path = os.path.join(save_dir, "prediction.npy")
10811081
targets_path = os.path.join(save_dir, "targets.npy")
1082+
info_path = os.path.join(save_dir, "data_info.json")
10821083

10831084
# create memmap files
10841085
if batch_idx == 0:
@@ -1088,6 +1089,14 @@ def _save_results(self, batch_idx: int, batch_data: Dict[str, torch.Tensor]) ->
10881089
shape=(total_samples, *prediction.shape[1:]))
10891090
self._targets_memmap = np.memmap(targets_path, dtype=targets.dtype, mode="w+",
10901091
shape=(total_samples, *targets.shape[1:]))
1092+
# save shape and dtype info
1093+
info = {
1094+
"inputs": {"shape": self._inputs_memmap.shape, "dtype": str(self._inputs_memmap.dtype)},
1095+
"prediction": {"shape": self._prediction_memmap.shape, "dtype": str(self._prediction_memmap.dtype)},
1096+
"targets": {"shape": self._targets_memmap.shape, "dtype": str(self._targets_memmap.dtype)},
1097+
}
1098+
with open(info_path, "w") as f:
1099+
json.dump(info, f, indent=4)
10911100

10921101
start = batch_idx * inputs.shape[0]
10931102
end = start + inputs.shape[0]

0 commit comments

Comments
 (0)