1+ import os , sys
2+ import h5py
3+ import torch
4+ import argparse
5+ import numpy as np
6+ from torch import package
7+ from datasets import load_dataset
8+ from torchvision .transforms import CenterCrop , Compose , ToTensor
9+ from multiprocessing import Pool , current_process
10+ from tqdm import tqdm
11+
12+ # Set up dataset
13+ crop = CenterCrop (144 )
14+ RGB_SCALES = {
15+ "u" : (2 , 1.5 ),
16+ "g" : (2 , 6.0 ),
17+ "r" : (1 , 3.4 ),
18+ "i" : (0 , 1.0 ),
19+ "z" : (0 , 2.2 ),
20+ }
21+
22+
23+ def decals_to_rgb (image , bands = ["g" , "r" , "z" ], scales = None , m = 0.03 , Q = 20.0 ):
24+ axes , scales = zip (* [RGB_SCALES [bands [i ]] for i in range (len (bands ))])
25+ scales = [scales [i ] for i in axes ]
26+ image = image .movedim (1 , - 1 ).flip ( - 1 )
27+ scales = torch .tensor (scales , dtype = torch .float32 ).to (image .device )
28+ I = torch .sum (torch .clamp (image * scales + m , min = 0 ), dim = - 1 ) / len (bands )
29+ fI = torch .arcsinh (Q * I ) / np .sqrt (Q )
30+ I += (I == 0.0 ) * 1e-6
31+ image = (image * scales + m ) * (fI / I ).unsqueeze (- 1 )
32+ image = torch .clamp (image , 0 , 1 )
33+ return image .movedim (- 1 , 1 )
34+
35+
36+
37+ def import_package (path : str , device : str = "cpu" ) -> torch .nn .Module :
38+ """Import a torch package from a given path"""
39+ importer = package .PackageImporter (path )
40+ model = importer .load_pickle ("network" , "network.pkl" , map_location = device )
41+ return model
42+
43+
44+ def process_file (args ) -> None :
45+ """Process a single file in the dataset"""
46+ file , save_dir , batch_size , gpu_id = args
47+ file_path = os .path .join (dset_root , file , '001-of-001.hdf5' )
48+
49+ # Set the GPU device for this process
50+ torch .cuda .set_device (gpu_id )
51+
52+ # Load the model
53+ astrodino = import_package ("/mnt/ceph/users/polymathic/astroclip/pretrained/astrodino.pt" ).to (torch .device (f'cuda:{ gpu_id } ' ))
54+
55+ embeddings = []
56+ with h5py .File (file_path , 'r' ) as f :
57+ img_batch = []
58+ for img in tqdm (f ['image_array' ]):
59+ # Convert to RGB
60+ img = crop (torch .tensor (img [[0 ,1 ,3 ]])) # get g,r,z
61+
62+ # Append to batch
63+ img_batch .append (img )
64+
65+ if len (img_batch ) == batch_size :
66+ with torch .no_grad ():
67+ images = torch .stack (img_batch ).cuda ()
68+ images = decals_to_rgb (images )
69+ emb = astrodino (images )
70+ embeddings .append (emb .cpu ().numpy ())
71+ im_batch = []
72+
73+ # Get ra, dec, obj_id
74+ ra = f ['RA' ][:]
75+ dec = f ['DEC' ][:]
76+ obj_id = f ['object_id' ][:]
77+
78+ # Concatenate embeddings
79+ embeddings = np .concatenate (embeddings , axis = 0 )
80+
81+ # Save embeddings
82+ save_dir = os .path .join (save_dir , file )
83+ if not os .path .exists (save_dir ):
84+ os .makedirs (save_dir )
85+
86+ save_path = os .path .join (save_dir , '001-of-001.hdf5' )
87+ with h5py .File (save_path , 'w' ) as f :
88+ f .create_dataset ('embeddings' , data = embeddings )
89+ f .create_dataset ('RA' , data = ra )
90+ f .create_dataset ('DEC' , data = dec )
91+ f .create_dataset ('object_id' , data = obj_id )
92+
93+
94+ def embed_legacysurvey (
95+ dset_root : str ,
96+ save_dir : str ,
97+ astrodino_dir : str ,
98+ batch_size = 512 ,
99+ num_gpus = 4
100+ ):
101+ # List all files in the dataset
102+ files = os .listdir (dset_root )
103+
104+ # Create arguments for each process
105+ args = [(f , save_dir , batch_size , i % num_gpus ) for i , f in enumerate (files )]
106+
107+ # Use multiprocessing to process files in parallel
108+ with Pool (processes = num_gpus ) as pool :
109+ pool .map (process_file , args )
110+
111+ if __name__ == '__main__' :
112+ parser = argparse .ArgumentParser ()
113+ parser .add_argument ('--dset_root' , type = str , required = True )
114+ parser .add_argument ('--save_dir' , type = str , required = True )
115+ parser .add_argument ('--astrodino_dir' , type = str , default = "/mnt/ceph/users/polymathic/astroclip/pretrained" )
116+ parser .add_argument ('--batch_size' , type = int , default = 512 )
117+ parser .add_argument ('--num_gpus' , type = int , default = 4 )
118+ args = parser .parse_args ()
119+
120+ # Run the embedding process
121+ embed_legacysurvey (dset_root , save_dir , astrodino_dir , batch_size , num_gpus )
0 commit comments