1- import os , sys
2- import h5py
3- import torch
41import argparse
2+ import os
3+ from multiprocessing import Pool
4+
5+ import h5py
56import numpy as np
6- from torch import package
7+ import torch
78from datasets import load_dataset
9+ from torch import package
810from torchvision .transforms import CenterCrop , Compose , ToTensor
9- from multiprocessing import Pool , current_process
1011from tqdm import tqdm
1112
1213# Set up dataset
2324def decals_to_rgb (image , bands = ["g" , "r" , "z" ], scales = None , m = 0.03 , Q = 20.0 ):
2425 axes , scales = zip (* [RGB_SCALES [bands [i ]] for i in range (len (bands ))])
2526 scales = [scales [i ] for i in axes ]
26- image = image .movedim (1 , - 1 ).flip ( - 1 )
27+ image = image .movedim (1 , - 1 ).flip (- 1 )
2728 scales = torch .tensor (scales , dtype = torch .float32 ).to (image .device )
2829 I = torch .sum (torch .clamp (image * scales + m , min = 0 ), dim = - 1 ) / len (bands )
2930 fI = torch .arcsinh (Q * I ) / np .sqrt (Q )
@@ -33,7 +34,6 @@ def decals_to_rgb(image, bands=["g", "r", "z"], scales=None, m=0.03, Q=20.0):
3334 return image .movedim (- 1 , 1 )
3435
3536
36-
3737def import_package (path : str , device : str = "cpu" ) -> torch .nn .Module :
3838 """Import a torch package from a given path"""
3939 importer = package .PackageImporter (path )
@@ -44,20 +44,22 @@ def import_package(path: str, device: str = "cpu") -> torch.nn.Module:
4444def process_file (args ) -> None :
4545 """Process a single file in the dataset"""
4646 file , save_dir , batch_size , gpu_id = args
47- file_path = os .path .join (dset_root , file , ' 001-of-001.hdf5' )
47+ file_path = os .path .join (dset_root , file , " 001-of-001.hdf5" )
4848
4949 # Set the GPU device for this process
5050 torch .cuda .set_device (gpu_id )
5151
5252 # Load the model
53- astrodino = import_package ("/mnt/ceph/users/polymathic/astroclip/pretrained/astrodino.pt" ).to (torch .device (f'cuda:{ gpu_id } ' ))
53+ astrodino = import_package (
54+ "/mnt/ceph/users/polymathic/astroclip/pretrained/astrodino.pt"
55+ ).to (torch .device (f"cuda:{ gpu_id } " ))
5456
5557 embeddings = []
56- with h5py .File (file_path , 'r' ) as f :
58+ with h5py .File (file_path , "r" ) as f :
5759 img_batch = []
58- for img in tqdm (f [' image_array' ]):
60+ for img in tqdm (f [" image_array" ]):
5961 # Convert to RGB
60- img = crop (torch .tensor (img [[0 ,1 , 3 ]])) # get g,r,z
62+ img = crop (torch .tensor (img [[0 , 1 , 3 ]])) # get g,r,z
6163
6264 # Append to batch
6365 img_batch .append (img )
@@ -71,9 +73,9 @@ def process_file(args) -> None:
7173 im_batch = []
7274
7375 # Get ra, dec, obj_id
74- ra = f ['RA' ][:]
75- dec = f [' DEC' ][:]
76- obj_id = f [' object_id' ][:]
76+ ra = f ["RA" ][:]
77+ dec = f [" DEC" ][:]
78+ obj_id = f [" object_id" ][:]
7779
7880 # Concatenate embeddings
7981 embeddings = np .concatenate (embeddings , axis = 0 )
@@ -83,39 +85,40 @@ def process_file(args) -> None:
8385 if not os .path .exists (save_dir ):
8486 os .makedirs (save_dir )
8587
86- save_path = os .path .join (save_dir , ' 001-of-001.hdf5' )
87- with h5py .File (save_path , 'w' ) as f :
88- f .create_dataset (' embeddings' , data = embeddings )
89- f .create_dataset ('RA' , data = ra )
90- f .create_dataset (' DEC' , data = dec )
91- f .create_dataset (' object_id' , data = obj_id )
88+ save_path = os .path .join (save_dir , " 001-of-001.hdf5" )
89+ with h5py .File (save_path , "w" ) as f :
90+ f .create_dataset (" embeddings" , data = embeddings )
91+ f .create_dataset ("RA" , data = ra )
92+ f .create_dataset (" DEC" , data = dec )
93+ f .create_dataset (" object_id" , data = obj_id )
9294
9395
9496def embed_legacysurvey (
95- dset_root : str ,
96- save_dir : str ,
97- astrodino_dir : str ,
98- batch_size = 512 ,
99- num_gpus = 4
97+ dset_root : str , save_dir : str , astrodino_dir : str , batch_size = 512 , num_gpus = 4
10098):
10199 # List all files in the dataset
102100 files = os .listdir (dset_root )
103101
104102 # Create arguments for each process
105103 args = [(f , save_dir , batch_size , i % num_gpus ) for i , f in enumerate (files )]
106-
104+
107105 # Use multiprocessing to process files in parallel
108106 with Pool (processes = num_gpus ) as pool :
109107 pool .map (process_file , args )
110-
111- if __name__ == '__main__' :
108+
109+
110+ if __name__ == "__main__" :
112111 parser = argparse .ArgumentParser ()
113- parser .add_argument ('--dset_root' , type = str , required = True )
114- parser .add_argument ('--save_dir' , type = str , required = True )
115- parser .add_argument ('--astrodino_dir' , type = str , default = "/mnt/ceph/users/polymathic/astroclip/pretrained" )
116- parser .add_argument ('--batch_size' , type = int , default = 512 )
117- parser .add_argument ('--num_gpus' , type = int , default = 4 )
112+ parser .add_argument ("--dset_root" , type = str , required = True )
113+ parser .add_argument ("--save_dir" , type = str , required = True )
114+ parser .add_argument (
115+ "--astrodino_dir" ,
116+ type = str ,
117+ default = "/mnt/ceph/users/polymathic/astroclip/pretrained" ,
118+ )
119+ parser .add_argument ("--batch_size" , type = int , default = 512 )
120+ parser .add_argument ("--num_gpus" , type = int , default = 4 )
118121 args = parser .parse_args ()
119122
120123 # Run the embedding process
121- embed_legacysurvey (dset_root , save_dir , astrodino_dir , batch_size , num_gpus )
124+ embed_legacysurvey (dset_root , save_dir , astrodino_dir , batch_size , num_gpus )
0 commit comments