Skip to content

Commit e26c370

Browse files
committed
update with precommit fix
1 parent 4c31229 commit e26c370

2 files changed

Lines changed: 39 additions & 36 deletions

File tree

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
import os, sys
2-
import h5py
3-
import torch
41
import argparse
2+
import os
3+
from multiprocessing import Pool
4+
5+
import h5py
56
import numpy as np
6-
from torch import package
7+
import torch
78
from datasets import load_dataset
9+
from torch import package
810
from torchvision.transforms import CenterCrop, Compose, ToTensor
9-
from multiprocessing import Pool, current_process
1011
from tqdm import tqdm
1112

1213
# Set up dataset
@@ -23,7 +24,7 @@
2324
def decals_to_rgb(image, bands=["g", "r", "z"], scales=None, m=0.03, Q=20.0):
2425
axes, scales = zip(*[RGB_SCALES[bands[i]] for i in range(len(bands))])
2526
scales = [scales[i] for i in axes]
26-
image = image.movedim(1, -1).flip( -1)
27+
image = image.movedim(1, -1).flip(-1)
2728
scales = torch.tensor(scales, dtype=torch.float32).to(image.device)
2829
I = torch.sum(torch.clamp(image * scales + m, min=0), dim=-1) / len(bands)
2930
fI = torch.arcsinh(Q * I) / np.sqrt(Q)
@@ -33,7 +34,6 @@ def decals_to_rgb(image, bands=["g", "r", "z"], scales=None, m=0.03, Q=20.0):
3334
return image.movedim(-1, 1)
3435

3536

36-
3737
def import_package(path: str, device: str = "cpu") -> torch.nn.Module:
3838
"""Import a torch package from a given path"""
3939
importer = package.PackageImporter(path)
@@ -44,20 +44,22 @@ def import_package(path: str, device: str = "cpu") -> torch.nn.Module:
4444
def process_file(args) -> None:
4545
"""Process a single file in the dataset"""
4646
file, save_dir, batch_size, gpu_id = args
47-
file_path = os.path.join(dset_root, file, '001-of-001.hdf5')
47+
file_path = os.path.join(dset_root, file, "001-of-001.hdf5")
4848

4949
# Set the GPU device for this process
5050
torch.cuda.set_device(gpu_id)
5151

5252
# Load the model
53-
astrodino = import_package("/mnt/ceph/users/polymathic/astroclip/pretrained/astrodino.pt").to(torch.device(f'cuda:{gpu_id}'))
53+
astrodino = import_package(
54+
"/mnt/ceph/users/polymathic/astroclip/pretrained/astrodino.pt"
55+
).to(torch.device(f"cuda:{gpu_id}"))
5456

5557
embeddings = []
56-
with h5py.File(file_path, 'r') as f:
58+
with h5py.File(file_path, "r") as f:
5759
img_batch = []
58-
for img in tqdm(f['image_array']):
60+
for img in tqdm(f["image_array"]):
5961
# Convert to RGB
60-
img = crop(torch.tensor(img[[0,1,3]])) # get g,r,z
62+
img = crop(torch.tensor(img[[0, 1, 3]])) # get g,r,z
6163

6264
# Append to batch
6365
img_batch.append(img)
@@ -71,9 +73,9 @@ def process_file(args) -> None:
7173
im_batch = []
7274

7375
# Get ra, dec, obj_id
74-
ra = f['RA'][:]
75-
dec = f['DEC'][:]
76-
obj_id = f['object_id'][:]
76+
ra = f["RA"][:]
77+
dec = f["DEC"][:]
78+
obj_id = f["object_id"][:]
7779

7880
# Concatenate embeddings
7981
embeddings = np.concatenate(embeddings, axis=0)
@@ -83,39 +85,40 @@ def process_file(args) -> None:
8385
if not os.path.exists(save_dir):
8486
os.makedirs(save_dir)
8587

86-
save_path = os.path.join(save_dir, '001-of-001.hdf5')
87-
with h5py.File(save_path, 'w') as f:
88-
f.create_dataset('embeddings', data=embeddings)
89-
f.create_dataset('RA', data=ra)
90-
f.create_dataset('DEC', data=dec)
91-
f.create_dataset('object_id', data=obj_id)
88+
save_path = os.path.join(save_dir, "001-of-001.hdf5")
89+
with h5py.File(save_path, "w") as f:
90+
f.create_dataset("embeddings", data=embeddings)
91+
f.create_dataset("RA", data=ra)
92+
f.create_dataset("DEC", data=dec)
93+
f.create_dataset("object_id", data=obj_id)
9294

9395

9496
def embed_legacysurvey(
    dset_root: str, save_dir: str, astrodino_dir: str, batch_size=512, num_gpus=4
):
    """Run ``process_file`` over every entry of ``dset_root`` using a process pool.

    Args:
        dset_root: Directory whose entries (as returned by ``os.listdir``) are
            handed to ``process_file`` one at a time.
        save_dir: Output directory, forwarded unchanged to ``process_file``.
        astrodino_dir: Accepted for interface compatibility; it is not
            referenced anywhere in this function's body.
        batch_size: Batch size, forwarded unchanged to ``process_file``.
        num_gpus: Number of worker processes in the pool, and the modulus used
            to assign a GPU id to each entry.
    """
    jobs = []
    for index, name in enumerate(os.listdir(dset_root)):
        # Round-robin the entries over GPU ids 0 .. num_gpus - 1.
        jobs.append((name, save_dir, batch_size, index % num_gpus))

    # ``Pool.map`` blocks until every job has been handled by a worker.
    with Pool(processes=num_gpus) as workers:
        workers.map(process_file, jobs)
110-
111-
if __name__ == "__main__":
    # Command-line entry point: parse the options, then launch the embedding run.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dset_root", type=str, required=True)
    parser.add_argument("--save_dir", type=str, required=True)
    parser.add_argument(
        "--astrodino_dir",
        type=str,
        default="/mnt/ceph/users/polymathic/astroclip/pretrained",
    )
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--num_gpus", type=int, default=4)
    args = parser.parse_args()

    # Run the embedding process. The parsed values live on the ``args``
    # namespace; this block never binds the bare names ``dset_root``,
    # ``save_dir`` etc., so they must be read off ``args``.
    # NOTE(review): ``process_file`` reads a name ``dset_root`` that is not in
    # its argument tuple -- confirm where that name is expected to come from.
    embed_legacysurvey(
        args.dset_root,
        args.save_dir,
        args.astrodino_dir,
        args.batch_size,
        args.num_gpus,
    )

astroclip/astrodino/embed_legacysurvey/launch_embedding.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,4 @@ export CUDA_LAUNCH_BLOCKING=1.
2323

2424
source /mnt/home/lparker/python_envs/toto/bin/activate
2525

26-
# The argparse parser changed in this commit declares a required --save_dir and
# no --save_root, so the output location must be passed as --save_dir.
# NOTE(review): assumes launch_embeddings.py is that parser's script -- confirm.
python launch_embeddings.py --dset_root "$dset_root" --save_dir "$save_root" --batch_size 512 --num_gpus "$SLURM_GPUS_PER_NODE"

0 commit comments

Comments
 (0)