Skip to content

Commit b419936

Browse files
committed
Implemented the resultant image builder. resultant image headers, BFE, optimizations still need to be done
1 parent dfe6d90 commit b419936

4 files changed

Lines changed: 139 additions & 81 deletions

File tree

roman_imsim/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
from .skycat import *
1717
from .stamp import *
1818
from .wcs import *
19+
from .resultants import *

roman_imsim/resultants.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import fitsio as fio
1+
import yaml
22
import galsim
33
import galsim.config
44
from galsim.config import InputLoader, RegisterInputType, RegisterValueType, RegisterObjectType
@@ -7,12 +7,13 @@
77
class ResultantDataLoader(object):
88
"""Read the resultant information from the resultant strategy."""
99

10-
_req_params = {"file_name": str, "strategy": str,}
10+
_req_params = {"file_name": str, "strategy_name": str,}
1111

12-
def __init__(self, file_name, strategy, logger=None):
12+
def __init__(self, file_name, strategy_name, logger=None):
1313
self.logger = galsim.config.LoggerWrapper(logger)
1414
self.file_name = file_name
15-
self.strategy = strategy
15+
self.strategy_name = strategy_name
16+
self.data = {}
1617

1718

1819
# try:
@@ -22,31 +23,43 @@ def __init__(self, file_name, strategy, logger=None):
2223
# self.logger.warning('Reading visit info from config file.')
2324

2425
def read_resultants(self):
25-
"""Read resultant info from the resultants file."""
26-
if self.file_name is None:
27-
raise ValueError("No resultants filename provided, trying to build from config information.")
28-
if self.strategy is None:
29-
raise ValueError("The strategy must be set when reading resultant strategy info from a resultants file.")
26+
"""Load the YAML file and get the requested strategy."""
27+
self.logger.info("Reading resultants from YAML file: %s", self.file_name)
28+
try:
29+
with open(self.file_name, "r") as f:
30+
all_strategies = yaml.safe_load(f)
31+
except Exception as e:
32+
raise IOError(f"Could not read YAML file '{self.file_name}': {e}")
3033

31-
self.logger.warning("Reading info from resultants file %s for strategy %s", self.file_name, self.strategy)
34+
if self.strategy_name not in all_strategies:
35+
raise ValueError(f"Strategy '{self.strategy_name}' not found in YAML file.")
3236

33-
data = fio.FITS(self.file_name)[-1][self.strategy]
37+
strategy = all_strategies[self.strategy_name]
38+
if not isinstance(strategy, list):
39+
raise ValueError(f"Invalid strategy format for '{self.strategy_name}': must be a list of lists.")
3440

35-
self.data = {}
36-
self.data["strategy"] = data["strategy"]
37-
self.data["dt"] = self.resultants_to_dt()
41+
self.data["strategy"] = strategy
42+
43+
def resultants_to_dt(self, config, base):
44+
"""Compute dt from list-of-lists."""
45+
strategy = self.data["strategy"]
46+
if len(strategy) < 2:
47+
raise ValueError("Need at least two resultants to compute dt.")
48+
49+
avg_last = sum(strategy[-1]) / len(strategy[-1])
50+
avg_second = sum(strategy[0]) / len(strategy[0])
3851

39-
def resultants_to_dt(self,config,base):
4052
if "exptime" in config:
4153
exptime = galsim.config.ParseValue(config, "exptime", base, float)[0]
4254
else:
4355
exptime = roman.exptime
44-
dt = exptime/((sum(self.strategy[-1])/len(self.strategy[-1]))-(sum(self.strategy[1])/len(self.strategy[1])))
56+
57+
dt = exptime / (avg_last - avg_second)
4558
return dt
46-
59+
4760
def get(self, field, default=None):
4861
if field not in self.data and default is None:
49-
raise KeyError("ResultantData field %s not present in data" % field)
62+
raise KeyError(f"Field '{field}' not found in data.")
5063
return self.data.get(field, default)
5164

5265
def ResultantData(config, base, value_type):
@@ -56,10 +69,14 @@ def ResultantData(config, base, value_type):
5669
kwargs, safe = galsim.config.GetAllParams(config, base, req=req)
5770
field = kwargs["field"]
5871

59-
val = value_type(rdata.get(field))
60-
return val, safe
72+
if field == "dt":
73+
val = rdata.resultants_to_dt(config, base)
74+
else:
75+
val = rdata.get(field)
76+
77+
return value_type(val), safe
6178

6279

6380
RegisterInputType("resultant_data", InputLoader(ResultantDataLoader, file_scope=True, takes_logger=True))
64-
RegisterValueType("ResultantData", ResultantData, [float,list], input_type="resultant_data")
81+
RegisterValueType("ResultantData", ResultantData, [float, list], input_type="resultant_data")
6582

roman_imsim/sca.py

Lines changed: 98 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
from galsim.config import RegisterImageType
77
from galsim.config.image_scattered import ScatteredImageBuilder
88
from galsim.image import Image
9+
import gc
910

1011

11-
class RomanSCAImageBuilder(ScatteredImageBuilder):
12+
class RomanSCAImageBuilderCMOS(ScatteredImageBuilder):
1213

1314
def setup(self, config, base, image_num, obj_num, ignore, logger):
1415
"""Do the initialization and setup for building the image.
@@ -123,8 +124,8 @@ def buildImage(self, config, base, image_num, obj_num, logger):
123124
full_xsize = base["image_xsize"]
124125
full_ysize = base["image_ysize"]
125126
wcs = base["wcs"]
126-
127-
127+
rdata = galsim.config.GetInputObj("resultant_data", config, base, "ResultantDataLoader")
128+
strategy = rdata.get("strategy")
128129
full_image = Image(full_xsize, full_ysize, dtype=float)
129130
full_image.setOrigin(base["image_origin"])
130131
full_image.wcs = wcs
@@ -155,62 +156,97 @@ def buildImage(self, config, base, image_num, obj_num, logger):
155156
"x": {"type": "Random", "min": xmin, "max": xmax},
156157
"y": {"type": "Random", "min": ymin, "max": ymax},
157158
}
158-
159-
nbatch = self.nobjects // 1000 + 1
160-
full_array = galsim.PhotonArray(0)
161-
for batch in range(nbatch):
162-
start_obj_num = self.nobjects * batch // nbatch
163-
end_obj_num = self.nobjects * (batch + 1) // nbatch
164-
nobj_batch = end_obj_num - start_obj_num
165-
if nbatch > 1:
166-
logger.warning(
167-
"Start batch %d/%d with %d objects [%d, %d)",
168-
batch + 1,
169-
nbatch,
170-
nobj_batch,
171-
start_obj_num,
172-
end_obj_num,
159+
#Create index and lists for resultant management
160+
max_dt = strategy[-1][-1]
161+
resultant_i = 0
162+
resultant_buffer= []
163+
#Iterate through all dt
164+
for dt in np.arange(1,max_dt+1):
165+
166+
nbatch = self.nobjects // 1000 + 1
167+
full_array = galsim.PhotonArray(0)
168+
for batch in range(nbatch):
169+
start_obj_num = self.nobjects * batch // nbatch
170+
end_obj_num = self.nobjects * (batch + 1) // nbatch
171+
nobj_batch = end_obj_num - start_obj_num
172+
if nbatch > 1:
173+
logger.warning(
174+
"Start batch %d/%d with %d objects [%d, %d)",
175+
batch + 1,
176+
nbatch,
177+
nobj_batch,
178+
start_obj_num,
179+
end_obj_num,
180+
)
181+
stamps, current_vars = galsim.config.BuildStamps(
182+
nobj_batch, base, logger=logger, obj_num=start_obj_num, do_noise=False
173183
)
174-
stamps, current_vars = galsim.config.BuildStamps(
175-
nobj_batch, base, logger=logger, obj_num=start_obj_num, do_noise=False
176-
)
177-
base["index_key"] = "image_num"
178-
179-
for k in range(nobj_batch):
180-
# This is our signal that the object was skipped.
181-
if stamps[k] is None:
182-
continue
183-
bounds = full_image.bounds # stamps[k].bounds &
184-
if not bounds.isDefined(): # pragma: no cover
185-
# These noramlly show up as stamp==None, but technically it is possible
186-
# to get a stamp that is off the main image, so check for that here to
187-
# avoid an error. But this isn't covered in the imsim test suite.
188-
continue
189-
190-
# logger.debug("image %d: full bounds = %s", image_num, str(full_image.bounds))
191-
# logger.debug(
192-
# "image %d: stamp %d bounds = %s",
193-
# image_num,
194-
# k + start_obj_num,
195-
# str(stamps[k].bounds),
196-
# )
197-
# logger.debug("image %d: Overlap = %s", image_num, str(bounds))
198-
# full_image[bounds] += stamps[k][bounds]
199-
#logger.warning(stamps[k])
200-
full_array = galsim.PhotonArray.concatenate([*stamps,full_array])
201-
202-
stamps = None
203-
204-
# # Bring the image so far up to a flat noise variance
205-
# current_var = FlattenNoiseVariance(
206-
# base, full_image, stamps, current_vars, logger)
207-
208-
full_array.addTo(full_image)
209-
full_array.write("photonarray.fits")
210-
full_image.write("phot_image.fits")
184+
base["index_key"] = "image_num"
185+
186+
for k in range(nobj_batch):
187+
# This is our signal that the object was skipped.
188+
if stamps[k] is None:
189+
continue
190+
bounds = full_image.bounds # stamps[k].bounds &
191+
if not bounds.isDefined(): # pragma: no cover
192+
# These noramlly show up as stamp==None, but technically it is possible
193+
# to get a stamp that is off the main image, so check for that here to
194+
# avoid an error. But this isn't covered in the imsim test suite.
195+
continue
196+
197+
# logger.debug("image %d: full bounds = %s", image_num, str(full_image.bounds))
198+
# logger.debug(
199+
# "image %d: stamp %d bounds = %s",
200+
# image_num,
201+
# k + start_obj_num,
202+
# str(stamps[k].bounds),
203+
# )
204+
# logger.debug("image %d: Overlap = %s", image_num, str(bounds))
205+
# full_image[bounds] += stamps[k][bounds]
206+
#logger.warning(stamps[k])
207+
full_array = galsim.PhotonArray.concatenate([*stamps,full_array])
208+
209+
stamps = None
210+
211+
# # Bring the image so far up to a flat noise variance
212+
# current_var = FlattenNoiseVariance(
213+
# base, full_image, stamps, current_vars, logger)
214+
#TODO : Apply BFE photon operation (uses current pre-read image and next photon array)
215+
216+
# Turn full_image into running pre-read image
217+
full_array.addTo(full_image)
218+
del full_array
219+
gc.collect()
220+
# Decide what to do with readout based on resultant strategy
221+
if (np.array([item for sub in strategy for item in sub]) == dt).any():
222+
if (np.array(strategy[resultant_i]) == dt).any():
223+
readout_im = full_image.copy()
224+
readout_im = self.addNoiseToImage(readout_im, config, base, logger)
225+
resultant_buffer.extend([readout_im])
226+
if len(resultant_buffer)>1:
227+
#combine readout images
228+
resultant_buffer[0].array = resultant_buffer[0].array + resultant_buffer[1].array
229+
del resultant_buffer[-1]
230+
gc.collect()
231+
232+
if np.array(strategy[resultant_i][-1]) == dt:
233+
divisor = len(np.array(strategy[resultant_i]))
234+
#divide summed images by the length of resultant to get the average and apply headers to the image array
235+
#TODO:apply header to the image array
236+
resultant_buffer[0].array = resultant_buffer[0].array/divisor
237+
resultant_buffer[0].write('resultant_{0}.fits'.format(resultant_i))
238+
resultant_i+=1
239+
resultant_buffer = []
240+
logger.warning('resultant{0} done'.format(resultant_i))
241+
242+
243+
244+
# full_array.write("photonarray.fits")
245+
# full_image.write("phot_image.fits")
246+
211247
return full_image, None
212248

213-
def addNoise(self, image, config, base, image_num, obj_num, current_var, logger):
249+
def addNoiseToImage(self, image, config, base, logger):
214250
"""Add the final noise to a Scattered image
215251
216252
Parameters:
@@ -224,7 +260,7 @@ def addNoise(self, image, config, base, image_num, obj_num, current_var, logger)
224260
"""
225261
# check ignore noise
226262
if self.ignore_noise:
227-
return
263+
return image
228264

229265
base["current_noise_image"] = base["current_image"]
230266
wcs = base["wcs"]
@@ -314,7 +350,11 @@ def addNoise(self, image, config, base, image_num, obj_num, current_var, logger)
314350
sky_image /= roman.gain
315351
sky_image.quantize()
316352
image -= sky_image
353+
return image
354+
355+
def addNoise(self, image, config, base, image_num, obj_num, current_var, logger):
356+
pass
317357

318358

319359
# Register this as a valid type
320-
RegisterImageType("roman_sca", RomanSCAImageBuilder())
360+
RegisterImageType("roman_sca_cmos", RomanSCAImageBuilderCMOS())

roman_imsim/stamp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -503,9 +503,9 @@ def add_poisson_noise(self, fft_image):
503503

504504
# Pick the right function to be _fix_seds.
505505
if galsim.__version_info__ < (2, 5):
506-
Roman_stamp.fix_seds = Roman_stamp._fix_seds_24
506+
Roman_stamp_CMOS.fix_seds = Roman_stamp_CMOS._fix_seds_24
507507
else:
508-
Roman_stamp.fix_seds = Roman_stamp._fix_seds_25
508+
Roman_stamp_CMOS.fix_seds = Roman_stamp_CMOS._fix_seds_25
509509

510510

511511
# Register this as a valid type

0 commit comments

Comments
 (0)