Skip to content

Commit 4c31229

Browse files
committed
embedding legacy survey with dino
1 parent 8cec8f6 commit 4c31229

2 files changed

Lines changed: 147 additions & 0 deletions

File tree

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import os, sys
2+
import h5py
3+
import torch
4+
import argparse
5+
import numpy as np
6+
from torch import package
7+
from datasets import load_dataset
8+
from torchvision.transforms import CenterCrop, Compose, ToTensor
9+
from multiprocessing import Pool, current_process
10+
from tqdm import tqdm
11+
12+
# Shared centre-crop transform applied to every image before embedding
# (keeps the central 144 x 144 pixels).
crop = CenterCrop(size=144)
14+
# Per-band (output plane index, flux scale) pairs used by ``decals_to_rgb``.
RGB_SCALES = {
    "u": (2, 1.5),
    "g": (2, 6.0),
    "r": (1, 3.4),
    "i": (0, 1.0),
    "z": (0, 2.2),
}


def decals_to_rgb(image, bands=("g", "r", "z"), scales=None, m=0.03, Q=20.0):
    """Convert a batch of multi-band images to RGB using an arcsinh stretch.

    Args:
        image: Tensor whose band axis is dimension 1, ordered like ``bands``
            (assumed layout ``(batch, band, height, width)`` -- confirm
            against callers).
        bands: Band names, each looked up in the scale table.
        scales: Optional mapping ``band -> (plane, scale)`` whose entries
            override ``RGB_SCALES``. ``None`` uses ``RGB_SCALES`` unchanged.
        m: Offset added to every scaled pixel before stretching.
        Q: Softening parameter of the arcsinh stretch.

    Returns:
        Tensor with the same layout as ``image``, values clamped to
        ``[0, 1]``, and the band axis reversed relative to the input
        (for the default bands: z, r, g -> R, G, B).

    Raises:
        KeyError: If a name in ``bands`` is missing from the scale table.
    """
    # Honour caller-supplied overrides instead of silently discarding them.
    band_table = RGB_SCALES if scales is None else {**RGB_SCALES, **scales}
    planes, band_scales = zip(*(band_table[band] for band in bands))
    # Reorder the scales by plane index so they line up with the flipped
    # channel axis below (for g, r, z this is simply the reversed order).
    channel_scales = [band_scales[plane] for plane in planes]

    # Channels-last and reversed band order.
    image = image.movedim(1, -1).flip(-1)
    channel_scales = torch.tensor(channel_scales, dtype=torch.float32).to(
        image.device
    )
    scaled = image * channel_scales + m

    intensity = torch.sum(torch.clamp(scaled, min=0), dim=-1) / len(bands)
    stretched = torch.arcsinh(Q * intensity) / np.sqrt(Q)
    # Avoid division by zero for pixels with no positive flux; ``stretched``
    # is already 0 there, so those pixels map to 0.
    intensity += (intensity == 0.0) * 1e-6

    rgb = torch.clamp(scaled * (stretched / intensity).unsqueeze(-1), 0, 1)
    return rgb.movedim(-1, 1)
34+
35+
36+
37+
def import_package(path: str, device: str = "cpu") -> torch.nn.Module:
    """Load the pickled ``network`` object stored in a ``torch.package`` archive.

    Args:
        path: Location of the package archive on disk.
        device: Passed as ``map_location`` when unpickling.

    Returns:
        The object stored under ``network/network.pkl`` in the archive.

    Note:
        Loading unpickles the archive contents, so only open trusted files.
    """
    return package.PackageImporter(path).load_pickle(
        "network", "network.pkl", map_location=device
    )
42+
43+
44+
# Directory that holds ``astrodino.pt`` when the caller does not supply one.
_DEFAULT_ASTRODINO_DIR = "/mnt/ceph/users/polymathic/astroclip/pretrained"


def _embed_batch(model, img_batch, device):
    """Stack, RGB-convert and embed one batch; return the result as a NumPy array."""
    with torch.no_grad():
        images = torch.stack(img_batch).to(device)
        images = decals_to_rgb(images)
        return model(images).cpu().numpy()


def process_file(args) -> None:
    """Embed every image of one dataset shard and write the result to HDF5.

    Args:
        args: Tuple ``(file, save_dir, batch_size, gpu_id, dset_root,
            astrodino_dir)``, packed into one argument so the function can be
            used with ``Pool.map``. The last two entries are optional: a
            missing ``dset_root`` falls back to a module-level ``dset_root``
            global if one exists, and a missing ``astrodino_dir`` falls back
            to the default pretrained-model directory.

    Raises:
        ValueError: If no dataset root is available.

    Side effects:
        Reads ``<dset_root>/<file>/001-of-001.hdf5`` and writes
        ``<save_dir>/<file>/001-of-001.hdf5`` containing ``embeddings``,
        ``RA``, ``DEC`` and ``object_id``.
    """
    file, save_dir, batch_size, gpu_id, *extra = args
    root = extra[0] if extra else globals().get("dset_root")
    if root is None:
        raise ValueError(
            "process_file needs the dataset root as the 5th element of args"
        )
    astrodino_dir = extra[1] if len(extra) > 1 else _DEFAULT_ASTRODINO_DIR
    file_path = os.path.join(root, file, '001-of-001.hdf5')

    # Pin this worker to its GPU.
    device = torch.device(f'cuda:{gpu_id}')
    torch.cuda.set_device(device)

    # Load the model
    astrodino = import_package(os.path.join(astrodino_dir, "astrodino.pt")).to(device)

    embeddings = []
    with h5py.File(file_path, 'r') as f:
        img_batch = []
        for img in tqdm(f['image_array']):
            # Select channels 0, 1, 3 (g, r, z per the original author's
            # note) and centre-crop.
            img_batch.append(crop(torch.tensor(img[[0, 1, 3]])))

            if len(img_batch) == batch_size:
                embeddings.append(_embed_batch(astrodino, img_batch, device))
                # Start a fresh batch; without this reset only the first
                # batch would ever be embedded.
                img_batch = []

        # Embed the trailing partial batch so no image is dropped.
        if img_batch:
            embeddings.append(_embed_batch(astrodino, img_batch, device))

        # Get ra, dec, obj_id
        ra = f['RA'][:]
        dec = f['DEC'][:]
        obj_id = f['object_id'][:]

    # Concatenate embeddings
    embeddings = np.concatenate(embeddings, axis=0)

    # Save embeddings; exist_ok avoids a race between worker processes.
    save_dir = os.path.join(save_dir, file)
    os.makedirs(save_dir, exist_ok=True)

    save_path = os.path.join(save_dir, '001-of-001.hdf5')
    with h5py.File(save_path, 'w') as f:
        f.create_dataset('embeddings', data=embeddings)
        f.create_dataset('RA', data=ra)
        f.create_dataset('DEC', data=dec)
        f.create_dataset('object_id', data=obj_id)


def embed_legacysurvey(
    dset_root: str,
    save_dir: str,
    astrodino_dir: str,
    batch_size=512,
    num_gpus=4
):
    """Embed every shard under ``dset_root`` using a pool of GPU workers.

    Args:
        dset_root: Directory whose entries are shard sub-directories.
        save_dir: Root directory for the output HDF5 files.
        astrodino_dir: Directory containing ``astrodino.pt``.
        batch_size: Number of images embedded per forward pass.
        num_gpus: Number of worker processes; shards are assigned to GPU
            ids round-robin.
    """
    # Sorted so the shard -> GPU assignment is reproducible between runs.
    files = sorted(os.listdir(dset_root))

    # Pass dset_root and astrodino_dir explicitly: workers must not rely on
    # module globals.
    args = [
        (name, save_dir, batch_size, idx % num_gpus, dset_root, astrodino_dir)
        for idx, name in enumerate(files)
    ]

    # NOTE(review): Pool uses the platform's default start method. The parent
    # process must not initialise CUDA before the pool is created if that
    # method is "fork" -- confirm, or switch to a "spawn" context.
    with Pool(processes=num_gpus) as pool:
        pool.map(process_file, args)
110+
111+
if __name__ == '__main__':
    # Command-line entry point: parse options and embed the whole dataset.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dset_root', type=str, required=True)
    parser.add_argument('--save_dir', type=str, required=True)
    parser.add_argument('--astrodino_dir', type=str, default="/mnt/ceph/users/polymathic/astroclip/pretrained")
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--num_gpus', type=int, default=4)
    args = parser.parse_args()

    # Run the embedding process. The values live on the parsed namespace;
    # the bare names used previously were undefined and raised NameError.
    embed_legacysurvey(
        dset_root=args.dset_root,
        save_dir=args.save_dir,
        astrodino_dir=args.astrodino_dir,
        batch_size=args.batch_size,
        num_gpus=args.num_gpus,
    )
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#!/bin/bash -l

#SBATCH -p gpu
#SBATCH -N 1
#SBATCH -C a100-80gb
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-node=4
#SBATCH --cpus-per-gpu=1
#SBATCH -t 168:00:00
#SBATCH --output=logs/out-%j.log
#SBATCH -J "embedding"

module purge
module load gcc

# Input dataset and output location for the embeddings.
# Shell assignments take no leading '$' and no spaces around '='.
dset_root="/mnt/ceph/users/polymathic/MultimodalUniverse/legacysurvey/dr10_south_21"
save_root="/mnt/ceph/users/polymathic/MultimodalUniverse/astrodino_legacysurvey"

export OMP_NUM_THREADS=${SLURM_CPUS_ON_NODE}

# Make CUDA kernel launches synchronous so errors surface at the failing call.
export CUDA_LAUNCH_BLOCKING=1

source /mnt/home/lparker/python_envs/toto/bin/activate

# The Python script's output option is --save_dir (there is no --save_root).
python launch_embeddings.py --dset_root "$dset_root" --save_dir "$save_root" --batch_size 512 --num_gpus "$SLURM_GPUS_PER_NODE"

0 commit comments

Comments
 (0)