-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuild_narrow_cache.py
More file actions
139 lines (109 loc) · 4.41 KB
/
build_narrow_cache.py
File metadata and controls
139 lines (109 loc) · 4.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
Build two narrow-HU caches from the existing 1-channel cache.
HU mapping:
HU < 48 → 0.0
48 ≤ HU ≤ 90 → (HU - 48) / 42 (linear 0.0–1.0)
HU > 90 → 1.0
Reads from /home/johnb/cache_1ch (1-ch wide-HU cache).
Writes to:
TRAIN_DIR — train_ids + val_ids (from data_splits.json)
TEST_DIR — test_ids (from data_splits.json)
Each output .npz stores:
image_norm : float16 (512, 512) — narrow-HU mapped values in [0, 1]
"""
import os
import sys
import json
import argparse
import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
import time
# Default locations; each path can be overridden from the command line
# (--src, --train-dir, --test-dir, --splits-file).
CACHE_SRC = '/home/johnb/cache_1ch'
TRAIN_DIR = '/home/johnb/cache_narrow_train'
TEST_DIR = '/home/johnb/cache_narrow_test'
SPLITS_FILE = '/home/johnb/NewICH/checkpoints_1ch/data_splits.json'
# Lower / upper bounds (in HU) of the narrow window used by narrow_hu_map().
HU_LOW = 48.0
HU_HIGH = 90.0
def recover_hu(image_norm: np.ndarray) -> np.ndarray:
    """Undo the cache_1ch normalisation and return HU values as float32.

    The source cache stores ``image_norm = (HU + 300) / 480`` (as noted for
    build_1ch_cache.py), so the inverse is ``HU = image_norm * 480 - 300``.
    """
    # Promote to float32 first so the arithmetic is not done in the
    # (possibly lower-precision) dtype of the stored array.
    as_f32 = image_norm.astype(np.float32)
    scaled = np.multiply(as_f32, 480.0)
    return np.subtract(scaled, 300.0)
def narrow_hu_map(hu: np.ndarray) -> np.ndarray:
    """Map HU values onto [0, 1] with a fixed HU_LOW–HU_HIGH window.

    Values at or below HU_LOW become 0.0, values at or above HU_HIGH become
    1.0, and values in between are scaled linearly. The result is float16.
    (The original comment states this matches hu_windows.WINDOW_NARROW.)
    """
    window_width = HU_HIGH - HU_LOW
    scaled = (hu - HU_LOW) / window_width
    return np.clip(scaled, 0.0, 1.0).astype(np.float16)
def convert_one(args):
    """Convert one cached image to the narrow-HU encoding (pool worker).

    Args:
        args: ``(image_id, src_dir, dst_dir)`` tuple. ``{image_id}.npz`` is
            read from ``src_dir`` and the converted file is written under the
            same name in ``dst_dir``.

    Returns:
        ``(image_id, ok, msg)``. ``msg`` is ``'skip'`` when the destination
        already exists, ``'ok'`` after a successful conversion, and the
        exception text (with ``ok`` set to ``False``) when anything fails.
    """
    image_id, src_dir, dst_dir = args
    src = os.path.join(src_dir, f"{image_id}.npz")
    dst = os.path.join(dst_dir, f"{image_id}.npz")
    if os.path.exists(dst):
        return image_id, True, 'skip'

    # Write to a per-process temporary name and rename into place, so an
    # interrupted run never leaves a truncated file at `dst` that the
    # exists-check above would silently skip on the next run.
    tmp = f"{dst}.{os.getpid()}.tmp"
    try:
        # NpzFile holds an open file handle; close it deterministically.
        with np.load(src) as data:
            hu = recover_hu(data['image_norm'])
        mapped = narrow_hu_map(hu)
        # Passing a file object keeps numpy from appending ".npz" to `tmp`.
        with open(tmp, 'wb') as fh:
            np.savez_compressed(fh, image_norm=mapped)
        os.replace(tmp, dst)
        return image_id, True, 'ok'
    except Exception as e:
        return image_id, False, str(e)
    finally:
        # Best-effort cleanup of a leftover temp file; after a successful
        # rename (or an early failure) it simply does not exist.
        try:
            os.remove(tmp)
        except OSError:
            pass
def process_batch(ids, src_dir, dst_dir, label, workers=12):
    """Convert every id in ``ids`` using a pool of worker processes.

    Args:
        ids: Iterable of image ids; each is converted via ``convert_one``.
        src_dir: Directory holding the source ``.npz`` files.
        dst_dir: Output directory (created if missing).
        label: Name used as a prefix in progress output.
        workers: Number of worker processes.

    Returns:
        The number of ids that failed to convert.
    """
    os.makedirs(dst_dir, exist_ok=True)
    tasks = [(img_id, src_dir, dst_dir) for img_id in ids]
    total = len(tasks)
    done = 0
    errors = 0
    skipped = 0
    t0 = time.perf_counter()
    print(f"\n{label}: {total:,} files → {dst_dir}")

    # Nothing to do for an empty batch, so do not spin up a process pool.
    if tasks:
        with ProcessPoolExecutor(max_workers=workers) as pool:
            futures = {pool.submit(convert_one, t): t[0] for t in tasks}
            for fut in as_completed(futures):
                try:
                    img_id, ok, msg = fut.result()
                except Exception as exc:
                    # convert_one reports its own failures; reaching here
                    # means the worker itself failed (e.g. the process died).
                    # Count it as an error instead of aborting the batch.
                    img_id, ok, msg = futures[fut], False, f"worker failed: {exc!r}"
                done += 1
                if not ok:
                    errors += 1
                    print(f" ERROR {img_id}: {msg}")
                elif msg == 'skip':
                    skipped += 1
                if done % 10000 == 0:
                    # Clamp so a zero-length interval cannot divide by zero.
                    elapsed = max(time.perf_counter() - t0, 1e-9)
                    rate = done / elapsed
                    eta = (total - done) / rate
                    print(f" {label}: {done:,}/{total:,} "
                          f"({100*done/total:.1f}%) "
                          f"{rate:.0f} files/s ETA {eta/60:.1f} min",
                          flush=True)

    elapsed = time.perf_counter() - t0
    print(f" Done: {done:,} files in {elapsed/60:.1f} min "
          f"({skipped:,} skipped, {errors:,} errors)")
    return errors
def main(argv=None):
    """Command-line entry point: build the train+val and test narrow caches.

    Reads ``train_ids``, ``val_ids`` and ``test_ids`` from the splits JSON,
    converts train+val into ``--train-dir`` and test into ``--test-dir``, and
    prints a summary.

    Args:
        argv: Optional argument list; ``None`` means use ``sys.argv[1:]``.

    Raises:
        SystemExit: with status 1 if any file failed to convert, so callers
            and shell pipelines can detect an incomplete cache.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--src', default=CACHE_SRC)
    parser.add_argument('--train-dir', default=TRAIN_DIR)
    parser.add_argument('--test-dir', default=TEST_DIR)
    parser.add_argument('--splits-file', default=SPLITS_FILE)
    parser.add_argument('--workers', type=int, default=12)
    args = parser.parse_args(argv)

    print(f"Loading splits from {args.splits_file}")
    with open(args.splits_file, encoding='utf-8') as f:
        splits = json.load(f)
    train_ids = splits['train_ids']
    val_ids = splits['val_ids']
    test_ids = splits['test_ids']

    # train_dir gets train+val (both needed for training runs)
    trainval_ids = train_ids + val_ids
    print(f" train+val : {len(trainval_ids):,}")
    print(f" test : {len(test_ids):,}")
    print(f" total : {len(trainval_ids) + len(test_ids):,}")
    print(f" workers : {args.workers}")

    t_start = time.perf_counter()
    errs1 = process_batch(trainval_ids, args.src, args.train_dir,
                          'TRAIN+VAL', args.workers)
    errs2 = process_batch(test_ids, args.src, args.test_dir,
                          'TEST', args.workers)
    total_elapsed = time.perf_counter() - t_start
    total_errors = errs1 + errs2
    print(f"\nAll done in {total_elapsed/60:.1f} min "
          f"(total errors: {total_errors})")

    # Signal failure through the exit status rather than only in the log.
    if total_errors:
        sys.exit(1)
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()