@@ -1079,6 +1079,7 @@ def _save_results(self, batch_idx: int, batch_data: Dict[str, torch.Tensor]) ->
10791079 inputs_path = os .path .join (save_dir , "inputs.npy" )
10801080 pred_path = os .path .join (save_dir , "prediction.npy" )
10811081 targets_path = os .path .join (save_dir , "targets.npy" )
1082+ info_path = os .path .join (save_dir , "data_info.json" )
10821083
10831084 # create memmap files
10841085 if batch_idx == 0 :
@@ -1088,6 +1089,14 @@ def _save_results(self, batch_idx: int, batch_data: Dict[str, torch.Tensor]) ->
10881089 shape = (total_samples , * prediction .shape [1 :]))
10891090 self ._targets_memmap = np .memmap (targets_path , dtype = targets .dtype , mode = "w+" ,
10901091 shape = (total_samples , * targets .shape [1 :]))
1092+ # save shape and dtype info
1093+ info = {
1094+ "inputs" : {"shape" : self ._inputs_memmap .shape , "dtype" : str (self ._inputs_memmap .dtype )},
1095+ "prediction" : {"shape" : self ._prediction_memmap .shape , "dtype" : str (self ._prediction_memmap .dtype )},
1096+ "targets" : {"shape" : self ._targets_memmap .shape , "dtype" : str (self ._targets_memmap .dtype )},
1097+ }
1098+ with open (info_path , "w" ) as f :
1099+ json .dump (info , f , indent = 4 )
10911100
10921101 start = batch_idx * inputs .shape [0 ]
10931102 end = start + inputs .shape [0 ]
0 commit comments