-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcreate_batches.py
More file actions
324 lines (227 loc) · 10.6 KB
/
create_batches.py
File metadata and controls
324 lines (227 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
import argparse
import math
import os
import random
import warnings
import astropy.units as u
import matplotlib.pyplot as plt
import multiprocessing as mp
import numpy as np
from astropy.coordinates import FK5, SkyCoord
from astropy.io import fits
from astropy.nddata.utils import Cutout2D
from astropy.table import Table
from astropy.wcs import WCS
from functools import partial
from json import load,dump
from pathlib import Path
from tqdm import tqdm
# Fix both random generators so the shuffling done later in this script is
# reproducible from run to run.
random.seed(0)
np.random.seed(0)

# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    '--catalogue', '-c',
    required=True,
    type=str,
    default=None,
    help="A fits file containing the ids, ra, dec and tile_ids for all the sources you want in your dataset.",
)
parser.add_argument(
    '--tile', '-t',
    required=False,
    type=int,
    default=None,
    help="The specific tile you want to generate cutouts for sources within --catalogue. Creating batches from single tiles may impact training due to irregularities in tiles being transferred across whole batches and so is only recommended for inference.",
)
parser.add_argument(
    '--batchsize', '-bs',
    type=int,
    default=1024,
    help="The number of cutouts within each .npy file. If tiles are missing then .npy will contain at most --batchsize cutouts but most will contain less. This does not affect diffusion pipeline.",
)
parser.add_argument(
    '--processes', '-p',
    type=int,
    default=32,
    help="The number of separate processes to run to create .npy files. More will run faster. If --processes is larger than possible for your system, it will automatically reduce the number to a more suitable amount.",
)
parser.add_argument(
    '--batch_dir', '-bd',
    type=str,
    default="data",
    help="The directory where all your .npy files will be saved.",
)
parser.add_argument(
    '--tile_dir', '-td',
    type=str,
    default="tiles",
    help="The directory where all your .fits tiles are located.",
)
args = parser.parse_args()

# Echo the two directories in use, then make sure the output one exists.
print(args.batch_dir)
print(args.tile_dir)
os.makedirs(args.batch_dir, exist_ok=True)

# Catalogue column names of interest.
# NOTE(review): nothing in the visible part of this file reads this list.
cols_to_keep = [
    "object_id_euclid",
    "tile_index_euclid",
    "right_ascension_euclid",
    "declination_euclid",
    "flux_detection_total_euclid",
    "fluxerr_detection_total_euclid",
    "segmentation_area",
]
def to_mag(F, zero_point=23.9):
    """Convert a flux to a magnitude: ``-2.5 * log10(F) + zero_point``.

    Parameters
    ----------
    F : float or array-like
        Flux value(s). Non-positive values yield ``nan``/``inf`` (NumPy
        emits a RuntimeWarning), exactly as ``np.log10`` does.
    zero_point : float, optional
        Additive magnitude zero point. The default of 23.9 is the value this
        function has always used.
        NOTE(review): 23.9 looks like the AB zero point for fluxes in
        microjansky -- confirm against the catalogue's flux unit.

    Returns
    -------
    float or numpy.ndarray
        Magnitude(s), with the same shape as ``F``.
    """
    return -2.5 * np.log10(F) + zero_point
def get_tile_short_dict(tile_dir=None):
    """Map each tile's short id to the path of its FITS file.

    Every ``*.fits`` file found recursively under ``tile_dir`` is indexed by
    a 9-character id cut out of its file name.

    The id used to be sliced out of the full path string at an offset
    computed from ``len(args.tile_dir)``. That only worked when the files sat
    directly inside the directory and the directory was spelled exactly as
    pathlib prints it (no trailing slash, no ``./`` prefix). Slicing the file
    name yields the same id in that case and keeps working for nested
    directories and other spellings.

    Parameters
    ----------
    tile_dir : str or path-like, optional
        Directory to scan. Defaults to ``args.tile_dir``.

    Returns
    -------
    dict[str, str]
        ``{tile id: path}``. If several files share an id, the one found
        last wins.
    """
    if tile_dir is None:
        tile_dir = args.tile_dir

    # Position of the tile id inside the file name: the original offsets
    # (path_len+30 .. path_len+39) minus the directory and its one separator.
    # NOTE(review): assumes every tile file name carries its id at
    # characters 29-37 -- confirm against the tile naming scheme.
    id_start, id_stop = 29, 38

    tile_short = {}
    for tile_path in Path(tile_dir).rglob('*.fits'):
        tile_short[tile_path.name[id_start:id_stop]] = str(tile_path)
    return tile_short
# Index the available tile files by their short id, then echo the parsed
# command-line options.
tile_short = get_tile_short_dict()
print(args)

# Read the source catalogue from the first extension of the FITS file and
# keep it as an astropy Table for the rest of the script.
# NOTE(review): the original indentation was lost; the Table is assumed to be
# built while the file is still open.
with fits.open(args.catalogue) as catalogue_hdul:
    data = Table(catalogue_hdul[1].data)

print(f"\ndata loaded... {len(data)} rows\n")
def _dump_source_map(source_ids, chunk_id, total_chunks):
    """Write the batch-index -> object-id map of this chunk to its JSON file."""
    if total_chunks > 1:
        json_name = f'batch_source_{chunk_id}.json'
    else:
        json_name = 'batch_source_full.json'
    with open(f'{args.batch_dir}/{json_name}', 'w') as f:
        dump(source_ids, f, indent=2)


def _flush_batch(b_idx, cutouts, ids, source_ids, chunk_id, total_chunks):
    """Save one batch of cutouts as ``<b_idx>.npy`` and record its object ids."""
    with open(f'{args.batch_dir}/{b_idx}.npy', 'wb') as f:
        # One stack per batch instead of one np.vstack per cutout.
        np.save(f, np.stack(cutouts))
    source_ids[f"{b_idx}"] = list(ids)
    # Rewritten after every batch so the map on disk tracks the .npy files.
    _dump_source_map(source_ids, chunk_id, total_chunks)


def save_batches_from_tiles(cat_subset, split, chunk_id, total_chunks, batches_per_chunk):
    """Cut a 128x128 stamp around each source and save the stamps in batches.

    For every object id in ``split`` the matching catalogue row is looked up,
    the tile it belongs to is opened (or, when ``--tile`` is given, that one
    tile is read once up front), and a 128x128 float32 cutout centred on the
    source position is extracted. Accepted cutouts are grouped into files
    ``<batch_dir>/<batch index>.npy`` of at most ``args.batchsize`` cutouts,
    and a JSON file maps each batch index to the object ids it holds.

    A source is skipped when
      * its tile is not present in ``tile_short`` (recorded as missing tile),
      * no full 128x128 cutout can be made, including when the position does
        not overlap the tile at all (recorded as image error), or
      * the cutout contains any pixel equal to 0 (recorded as image error).

    Fixes relative to the previous implementation: the trailing save after
    the loop used to run unconditionally. After the last batch had filled up
    exactly it either re-saved that batch under the *next* batch index (in a
    multi-chunk run that is the first file of the following chunk), or it
    pickled ``None`` into a ``.npy`` file and left an empty id list in the
    JSON. Now only a non-empty, not-yet-saved batch is written.

    Parameters
    ----------
    cat_subset : astropy.table.Table
        Catalogue with (at least) the columns ``object_id_euclid``,
        ``tile_index_euclid``, ``right_ascension_euclid`` and
        ``declination_euclid``.
    split : sequence
        Object ids this call is responsible for.
    chunk_id : int
        Index of this chunk; with more than one chunk it selects the first
        batch index (``chunk_id * batches_per_chunk``) and the JSON name.
    total_chunks : int
        Number of chunks the run is divided into.
    batches_per_chunk : int
        Number of batch indices reserved for each chunk.

    Returns
    -------
    tuple
        ``(number of saved cutouts, {object id: error text}, [missing tile ids])``.
    """
    # Local import: this name is not among the file-level imports.
    from astropy.nddata.utils import NoOverlapError

    cutout_shape = (128, 128)
    image_errors = {}
    missing_tile_ids = []
    source_ids = {}

    # With several chunks each one writes to its own range of batch indices.
    b_idx = chunk_id * batches_per_chunk if total_chunks > 1 else 0

    # Cutouts and ids collected for the batch currently being filled.
    pending_cutouts = []
    pending_ids = []

    data_t = header = wcs = None
    if args.tile is not None:
        # Single-tile mode: read the tile and build its WCS once.
        with fits.open(tile_short[f"{args.tile}"]) as hdul:
            data_t = hdul[0].data
            header = hdul[0].header
        wcs = WCS(header)

    for img_id in tqdm(split):
        # NOTE(review): this masks the whole catalogue once per source; a
        # prebuilt id -> row index would avoid the quadratic cost.
        r = cat_subset[cat_subset["object_id_euclid"] == img_id]
        source_key = str(r["object_id_euclid"].value[0])

        if args.tile is None:
            tile_id = str(r["tile_index_euclid"][0])
            if tile_id not in tile_short:
                if not missing_tile_ids:
                    # Warn only for the first missing tile; the full list is
                    # returned to the caller.
                    message = f'\nMissing tile {tile_id}, all missing tiles will be outputted in missing_tiles.txt after batching complete'
                    warnings.warn(message)
                missing_tile_ids.append(tile_id)
                continue
            with fits.open(tile_short[tile_id]) as hdul:
                data_t = hdul[0].data
                header = hdul[0].header
            wcs = WCS(header)

        ra, dec = (r["right_ascension_euclid"], r["declination_euclid"])
        position = SkyCoord(ra*u.deg, dec*u.deg, frame=FK5)
        try:
            cutout = Cutout2D(data_t, position, 128*u.pix, wcs=wcs)
        except NoOverlapError:
            # The position lies completely outside the tile: skip this source
            # instead of aborting the whole chunk.
            image_errors[source_key] = "cant create 128x128 cutout"
            continue
        cutout = cutout.data.astype(np.float32)

        if cutout.shape != cutout_shape:
            image_errors[source_key] = "cant create 128x128 cutout"
            continue
        if np.any(cutout == 0):
            image_errors[source_key] = "0-value pixels in image"
            continue

        pending_cutouts.append(cutout)
        pending_ids.append(source_key)

        if len(pending_cutouts) == args.batchsize:
            _flush_batch(b_idx, pending_cutouts, pending_ids, source_ids, chunk_id, total_chunks)
            b_idx += 1
            pending_cutouts = []
            pending_ids = []

    if pending_cutouts:
        # Trailing, partially filled batch.
        _flush_batch(b_idx, pending_cutouts, pending_ids, source_ids, chunk_id, total_chunks)
    else:
        # Nothing left to save, but always leave a JSON file for this chunk
        # so the caller finds one even when no cutout was accepted.
        _dump_source_map(source_ids, chunk_id, total_chunks)

    return sum(len(ids) for ids in source_ids.values()), image_errors, missing_tile_ids
def process_chunk(chunk_id, total_chunks, sources_id):
    """Select the object ids belonging to one chunk and batch them.

    The id list is cut into ``total_chunks`` consecutive slices whose length
    is a whole number of batches; the last chunk additionally takes whatever
    is left over. With a single chunk the complete list is used.

    Parameters
    ----------
    chunk_id : int
        Index of the chunk to process.
    total_chunks : int
        Number of chunks the id list is divided into.
    sources_id : list
        All object ids of the run.

    Returns
    -------
    tuple
        Whatever ``save_batches_from_tiles`` returns for this chunk.
    """
    batch_size = args.batchsize
    n_sources = len(sources_id)
    # Rows per chunk, rounded down to a multiple of the batch size.
    rows_per_chunk = (n_sources // batch_size) // total_chunks * batch_size

    if total_chunks <= 1:
        chunk_sources = sources_id.copy()
    else:
        first = chunk_id * rows_per_chunk
        if chunk_id == total_chunks - 1:
            # Last chunk: take everything up to the end of the list.
            last = n_sources
        else:
            last = first + rows_per_chunk
            # NOTE(review): the original indentation was lost; this check is
            # assumed to apply to the non-final chunks only.
            chunk_rows = last - first
            assert chunk_rows % batch_size == 0, f"Chunk {chunk_id} has {chunk_rows} rows, not divisible by batchsize {args.batchsize}"
        chunk_sources = sources_id[first:last].copy()

    return save_batches_from_tiles(
        data,
        chunk_sources,
        chunk_id,
        total_chunks,
        rows_per_chunk // batch_size,
    )
def _merge_chunk_source_maps(batch_dir, num_chunks):
    """Combine the per-chunk ``batch_source_<chunk_id>.json`` files.

    Only the files of chunk ids ``0 .. num_chunks-1`` are read, so unrelated
    or stale JSON files in ``batch_dir`` are neither merged nor deleted.

    Returns
    -------
    tuple
        ``(combined map {batch index: [object ids]}, [paths that were read])``.

    Raises
    ------
    ValueError
        If a batch index appears in more than one chunk file (or a file is
        not valid JSON).
    """
    combined = {}
    chunk_files = []
    for chunk_id in range(num_chunks):
        chunk_file = Path(batch_dir) / f"batch_source_{chunk_id}.json"
        if not chunk_file.exists():
            continue
        chunk_files.append(chunk_file)
        with open(chunk_file, 'r') as fp:
            chunk_map = load(fp)
        for batch_key, ids in chunk_map.items():
            if batch_key in combined:
                raise ValueError(f"batch {batch_key} is listed in more than one chunk file")
            combined[batch_key] = [str(k) for k in ids]
    return combined, chunk_files


# The guard keeps worker processes from re-running the driver when the
# multiprocessing start method re-imports this module (e.g. "spawn").
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Select the sources to process
    # ------------------------------------------------------------------
    if args.tile is not None:
        # Explicit checks instead of assert, which is stripped under -O.
        if f"{args.tile}" not in tile_short:
            parser.error(f"{args.tile} not in {args.tile_dir}")
        sources_id = [x["object_id_euclid"] for x in data if x["tile_index_euclid"] == args.tile]
        if not sources_id:
            # The message used to print the tile directory instead of the tile.
            parser.error(f"TILE {args.tile} does not contain any sources from {args.catalogue}")
        print(f"TILE {args.tile} contains {len(sources_id)} sources from {args.catalogue}\n")
    else:
        sources_id = list(data["object_id_euclid"])
    # NOTE(review): the original indentation was lost; the shuffle is assumed
    # to apply in both modes.
    random.shuffle(sources_id)

    # ------------------------------------------------------------------
    # Choose the number of worker processes
    # ------------------------------------------------------------------
    # Every process needs at least one full batch. The previous countdown
    # loop divided by zero when the catalogue had fewer rows than
    # --batchsize; in that case a single process now handles everything.
    total_rows = len(sources_id)
    full_batches = total_rows // args.batchsize
    num_processes = max(1, min(args.processes, full_batches))
    if num_processes != args.processes:
        print("num processes reduced to: ",num_processes)

    # ------------------------------------------------------------------
    # Create the batches
    # ------------------------------------------------------------------
    with mp.Pool(processes=num_processes) as pool:
        process_func = partial(process_chunk,
                               total_chunks=num_processes,
                               sources_id=sources_id)
        try:
            all_results = pool.map(process_func, range(num_processes))
        except KeyboardInterrupt:
            print("Interrupted by user")
            pool.terminate()
            pool.join()
            # No results exist, so stop here rather than run the reports
            # below on undefined values.
            raise SystemExit(1)
        except Exception as e:
            print(f"Error occurred: {e}")
            pool.terminate()
            pool.join()
            raise SystemExit(1)

    # Each result is (saved count, image errors, missing tile ids).
    per_chunk_saved = [result[0] for result in all_results]
    total_processed = sum(per_chunk_saved)
    image_errors = {}
    missing_tile_ids = []
    for _, chunk_errors, chunk_missing in all_results:
        image_errors.update(chunk_errors)
        missing_tile_ids.extend(chunk_missing)
    print(per_chunk_saved)
    print("\nAll processes completed successfully\n")

    # ------------------------------------------------------------------
    # Merge the per-chunk id maps into batch_source_full.json
    # ------------------------------------------------------------------
    # A single process writes batch_source_full.json itself. Merging in that
    # case found no per-chunk files and overwrote the good file with {}.
    if num_processes > 1:
        try:
            combined_dict, chunk_files = _merge_chunk_source_maps(args.batch_dir, num_processes)
        except (OSError, ValueError) as e:
            print(f"Error occurred: {e}")
        else:
            with open(f'{args.batch_dir}/batch_source_full.json','w') as fp:
                dump(combined_dict, fp, indent=2)
            full_json_amount = sum(len(ids) for ids in combined_dict.values())
            if full_json_amount == total_processed:
                for chunk_file in chunk_files:
                    os.remove(chunk_file)
            else:
                # Keep the per-chunk files so the mismatch can be inspected.
                print(f"Error occurred: merged id map lists {full_json_amount} sources but {total_processed} were saved; per-chunk json files kept")

    # ------------------------------------------------------------------
    # Reports
    # ------------------------------------------------------------------
    if missing_tile_ids:
        # Sorted so the file content is stable between runs.
        missing_tile_ids = sorted(set(missing_tile_ids))
        with open('missing_tiles.txt', 'w') as f:
            f.writelines(f"{tile_id}\n" for tile_id in missing_tile_ids)
    if missing_tile_ids:
        tiles_message = f"Missing {len(missing_tile_ids)} tiles - See missing_tiles.txt for details.\n"
    else:
        tiles_message = "All required tiles found!\n"
    print(tiles_message)

    missing = total_rows - total_processed
    if missing > 0:
        errors = ["ID, ERROR"]
        errors.extend(f"{k}, {v}" for k, v in image_errors.items())
        with open('image_errors.csv', 'w') as f:
            f.writelines(f"{line}\n" for line in errors)
    if image_errors:
        image_message = f"{len(sources_id) - total_processed} unsaved sources. See image_errors.csv for details.\n"
    else:
        image_message = "All sources saved successfully!\n"
    print(image_message)