diff --git a/roman_imsim/__init__.py b/roman_imsim/__init__.py index 72da6e7a..d3470020 100644 --- a/roman_imsim/__init__.py +++ b/roman_imsim/__init__.py @@ -16,3 +16,4 @@ from .skycat import * from .stamp import * from .wcs import * +from .resultants import * diff --git a/roman_imsim/resultants.py b/roman_imsim/resultants.py new file mode 100644 index 00000000..87318735 --- /dev/null +++ b/roman_imsim/resultants.py @@ -0,0 +1,85 @@ +import galsim +import galsim.config +import galsim.roman as roman +import yaml +from galsim.config import InputLoader, RegisterInputType, RegisterValueType + + +class ResultantDataLoader(object): + """Read the resultant information from the resultant strategy.""" + + _req_params = { + "file_name": str, + "strategy_name": str, + } + + def __init__(self, file_name, strategy_name, logger=None): + self.logger = galsim.config.LoggerWrapper(logger) + self.file_name = file_name + self.strategy_name = strategy_name + self.data = {} + + # try: + self.read_resultants() + # except: + # # Read visit info from the config file. + # self.logger.warning('Reading visit info from config file.') + + def read_resultants(self): + """Load the YAML file and get the requested strategy.""" + self.logger.info("Reading resultants from YAML file: %s", self.file_name) + try: + with open(self.file_name, "r") as f: + all_strategies = yaml.safe_load(f) + except Exception as e: + raise IOError(f"Could not read YAML file '{self.file_name}': {e}") + + if self.strategy_name not in all_strategies: + raise ValueError(f"Strategy '{self.strategy_name}' not found in YAML file.") + + strategy = all_strategies[self.strategy_name] + if not isinstance(strategy, list): + raise ValueError(f"Invalid strategy format for '{self.strategy_name}': must be a list of lists.") + + self.data["strategy"] = strategy + + def resultants_to_dt(self, config, base): + """Compute dt from list-of-lists.""" + strategy = self.data["strategy"] + if len(strategy) < 2: + raise ValueError("Need at least two resultants to compute dt.") + + avg_last = sum(strategy[-1]) / len(strategy[-1]) + avg_second = sum(strategy[0]) / len(strategy[0]) + + if "exptime" in config: + exptime = galsim.config.ParseValue(config, "exptime", base, float)[0] + else: + exptime = roman.exptime + + dt = exptime / (avg_last - avg_second) + return dt + + def get(self, field, default=None): + if field not in self.data and default is None: + raise KeyError(f"Field '{field}' not found in data.") + return self.data.get(field, default) + + +def ResultantData(config, base, value_type): + """Returns the resultant dt data.""" + rdata = galsim.config.GetInputObj("resultant_data", config, base, "ResultantDataLoader") + req = {"field": str} + kwargs, safe = galsim.config.GetAllParams(config, base, req=req) + field = kwargs["field"] + + if field == "dt": + val = rdata.resultants_to_dt(config, base) + else: + val = rdata.get(field) + + return value_type(val), safe + + +RegisterInputType("resultant_data", InputLoader(ResultantDataLoader, file_scope=True, takes_logger=True)) +RegisterValueType("ResultantData", ResultantData, [float, list], input_type="resultant_data") diff --git a/roman_imsim/sca.py b/roman_imsim/sca.py index 4d5dc15b..745ebdb8 100644 --- a/roman_imsim/sca.py +++ b/roman_imsim/sca.py @@ -1,3 +1,4 @@ +import gc import galsim import galsim.config import galsim.roman as roman @@ -308,5 +309,362 @@ def addNoise(self, image, config, base, image_num, obj_num, current_var, logger) image -= sky_image +class RomanSCAImageBuilderCMOS(ScatteredImageBuilder): + + def setup(self, config, base, image_num, obj_num, ignore, logger): + """Do the initialization and setup for building the image. + + This figures out the size that the image will be, but doesn't actually build it yet. + + Parameters: + config: The configuration dict for the image field. + base: The base configuration dict. + image_num: The current image number. + obj_num: The first object number in the image. + ignore: A list of parameters that are allowed to be in config that we can + ignore here. i.e. it won't be an error if these parameters are present. + logger: If given, a logger object to log progress. + + Returns: + xsize, ysize + """ + # import os, psutil + # process = psutil.Process() + # print('sca setup 1',process.memory_info().rss) + logger.debug( + "image %d: Building RomanSCA: image, obj = %d,%d", + image_num, + image_num, + obj_num, + ) + + self.nobjects = self.getNObj(config, base, image_num, logger=logger) + logger.debug("image %d: nobj = %d", image_num, self.nobjects) + + # These are allowed for Scattered, but we don't use them here. + extra_ignore = [ + "image_pos", + "world_pos", + "stamp_size", + "stamp_xsize", + "stamp_ysize", + "nobjects", + ] + req = {"SCA": int, "filter": str, "mjd": float, "exptime": float} + opt = { + "draw_method": str, + "use_fft_bright": bool, + "stray_light": bool, + "thermal_background": bool, + "reciprocity_failure": bool, + "dark_current": bool, + "nonlinearity": bool, + "ipc": bool, + "read_noise": bool, + "sky_subtract": bool, + "ignore_noise": bool, + } + params = galsim.config.GetAllParams(config, base, req=req, opt=opt, ignore=ignore + extra_ignore)[0] + + self.sca = params["SCA"] + base["SCA"] = self.sca + self.filter = params["filter"] + self.mjd = params["mjd"] + self.exptime = params["exptime"] + + self.ignore_noise = params.get("ignore_noise", False) + # self.exptime = params.get('exptime', roman.exptime) # Default is roman standard exposure time. + self.stray_light = params.get("stray_light", False) + self.thermal_background = params.get("thermal_background", False) + self.reciprocity_failure = params.get("reciprocity_failure", False) + self.dark_current = params.get("dark_current", False) + self.nonlinearity = params.get("nonlinearity", False) + self.ipc = params.get("ipc", False) + self.read_noise = params.get("read_noise", False) + self.sky_subtract = params.get("sky_subtract", False) + + # If draw_method isn't in image field, it may be in stamp. Check. + self.draw_method = params.get("draw_method", base.get("stamp", {}).get("draw_method", "phot")) + + # pointing = CelestialCoord(ra=params['ra'], dec=params['dec']) + # wcs = roman.getWCS(world_pos = pointing, + # PA = params['pa']*galsim.degrees, + # date = params['date'], + # SCAs = self.sca, + # PA_is_FPA = True + # )[self.sca] + + # # GalSim expects a wcs in the image field. + # config['wcs'] = wcs + + # If user hasn't overridden the bandpass to use, get the standard one. + if "bandpass" not in config: + base["bandpass"] = galsim.config.BuildBandpass(base["image"], "bandpass", base, logger=logger) + + return roman.n_pix, roman.n_pix + + # def getBandpass(self, filter_name): + # if not hasattr(self, 'all_roman_bp'): + # self.all_roman_bp = roman.getBandpasses() + # return self.all_roman_bp[filter_name] + + def buildImage(self, config, base, image_num, obj_num, logger): + """Build an Image containing multiple objects placed at arbitrary locations. + + Parameters: + config: The configuration dict for the image field. + base: The base configuration dict. + image_num: The current image number. + obj_num: The first object number in the image. + logger: If given, a logger object to log progress. + + Returns: + the final image and the current noise variance in the image as a tuple + """ + full_xsize = base["image_xsize"] + full_ysize = base["image_ysize"] + wcs = base["wcs"] + rdata = galsim.config.GetInputObj("resultant_data", config, base, "ResultantDataLoader") + strategy = rdata.get("strategy") + full_image = Image(full_xsize, full_ysize, dtype=float) + full_image.setOrigin(base["image_origin"]) + full_image.wcs = wcs + full_image.setZero() + + full_image.header = galsim.FitsHeader() + full_image.header["EXPTIME"] = self.exptime + full_image.header["MJD-OBS"] = self.mjd + full_image.header["DATE-OBS"] = Time(self.mjd, format="mjd").datetime.isoformat() + full_image.header["FILTER"] = self.filter + full_image.header["ZPTMAG"] = 2.5 * np.log10(self.exptime * roman.collecting_area) + + base["current_image"] = full_image + + if "image_pos" in config and "world_pos" in config: + raise galsim.GalSimConfigValueError( + "Both image_pos and world_pos specified for Scattered image.", + (config["image_pos"], config["world_pos"]), + ) + + if "image_pos" not in config and "world_pos" not in config: + xmin = base["image_origin"].x + xmax = xmin + full_xsize - 1 + ymin = base["image_origin"].y + ymax = ymin + full_ysize - 1 + config["image_pos"] = { + "type": "XY", + "x": {"type": "Random", "min": xmin, "max": xmax}, + "y": {"type": "Random", "min": ymin, "max": ymax}, + } + # Create index and lists for resultant management + max_dt = strategy[-1][-1] + resultant_i = 0 + resultant_buffer = [] + # Iterate through all dt + if "_global" not in base: + base["_global"] = {} + if "stamp_setup_cache" not in base["_global"]: + base["_global"]["stamp_setup_cache"] = {} + + for dt in np.arange(1, max_dt + 1): + + nbatch = self.nobjects // 1000 + 1 + full_array = galsim.PhotonArray(0) + + for batch in range(nbatch): + start_obj_num = self.nobjects * batch // nbatch + end_obj_num = self.nobjects * (batch + 1) // nbatch + nobj_batch = end_obj_num - start_obj_num + if nbatch > 1: + logger.warning( + "Start batch %d/%d with %d objects [%d, %d)", + batch + 1, + nbatch, + nobj_batch, + start_obj_num, + end_obj_num, + ) + stamps, current_vars = galsim.config.BuildStamps( + nobj_batch, base, logger=logger, obj_num=start_obj_num, do_noise=False + ) + logger.warning(base["_global"]["stamp_setup_cache"][batch]) + base["index_key"] = "image_num" + + for k in range(nobj_batch): + # This is our signal that the object was skipped. + if stamps[k] is None: + continue + bounds = full_image.bounds # stamps[k].bounds & + if not bounds.isDefined(): # pragma: no cover + # These noramlly show up as stamp==None, but technically it is possible + # to get a stamp that is off the main image, so check for that here to + # avoid an error. But this isn't covered in the imsim test suite. + continue + + # logger.debug("image %d: full bounds = %s", image_num, str(full_image.bounds)) + # logger.debug( + # "image %d: stamp %d bounds = %s", + # image_num, + # k + start_obj_num, + # str(stamps[k].bounds), + # ) + # logger.debug("image %d: Overlap = %s", image_num, str(bounds)) + # full_image[bounds] += stamps[k][bounds] + # logger.warning(stamps[k]) + full_array = galsim.PhotonArray.concatenate([*stamps, full_array]) + + stamps = None + + # # Bring the image so far up to a flat noise variance + # current_var = FlattenNoiseVariance( + # base, full_image, stamps, current_vars, logger) + # TODO : Apply BFE photon operation (uses current pre-read image and next photon array) + + # Turn full_image into running pre-read image + full_array.addTo(full_image) + del full_array + gc.collect() + # Decide what to do with readout based on resultant strategy + if ( + np.array([item for sub in strategy for item in sub]) == dt + ).any(): # does this dt exist in strategy + if (np.array(strategy[resultant_i]) == dt).any(): # is it in our current resultant + readout_im = full_image.copy() + readout_im = self.addNoiseToImage(readout_im, config, base, logger) + resultant_buffer.extend([readout_im]) + if len(resultant_buffer) > 1: + # combine readout images + resultant_buffer[0].array = resultant_buffer[0].array + resultant_buffer[1].array + del resultant_buffer[-1] + gc.collect() + + if np.array(strategy[resultant_i][-1]) == dt: # is this dt the last in our current resultant + divisor = len(np.array(strategy[resultant_i])) + # divide summed images by the length of resultant to get the average + # apply headers to the image array + # TODO:apply header to the image array + resultant_buffer[0].array = resultant_buffer[0].array / divisor + resultant_buffer[0].write("resultant_{0}.fits".format(resultant_i)) + resultant_i += 1 + resultant_buffer = [] + logger.warning("resultant{0} done".format(resultant_i)) + + # full_array.write("photonarray.fits") + # full_image.write("phot_image.fits") + + return full_image, None + + def addNoiseToImage(self, image, config, base, logger): + """Add the final noise to a Scattered image + + Parameters: + image: The image onto which to add the noise. + config: The configuration dict for the image field. + base: The base configuration dict. + image_num: The current image number. + obj_num: The first object number in the image. + current_var: The current noise variance in each postage stamps. + logger: If given, a logger object to log progress. + """ + # check ignore noise + if self.ignore_noise: + return image + + base["current_noise_image"] = base["current_image"] + wcs = base["wcs"] + bp = base["bandpass"] + rng = galsim.config.GetRNG(config, base) + logger.info("image %d: Start RomanSCA detector effects", base.get("image_num", 0)) + + # Things that will eventually be subtracted (if sky_subtract) will have their expectation + # value added to sky_image. So technically, this includes things that aren't just sky. + # E.g. includes dark_current and thermal backgrounds. + sky_image = image.copy() + sky_level = roman.getSkyLevel(bp, world_pos=wcs.toWorld(image.true_center)) + logger.debug("Adding sky_level = %s", sky_level) + if self.stray_light: + logger.debug("Stray light fraction = %s", roman.stray_light_fraction) + sky_level *= 1.0 + roman.stray_light_fraction + wcs.makeSkyImage(sky_image, sky_level) + + # The other background is the expected thermal backgrounds in this band. + # These are provided in e-/pix/s, so we have to multiply by the exposure time. + if self.thermal_background: + tb = roman.thermal_backgrounds[self.filter] * self.exptime + logger.debug("Adding thermal background: %s", tb) + sky_image += roman.thermal_backgrounds[self.filter] * self.exptime + + # The image up to here is an expectation value. + # Realize it as an integer number of photons. + poisson_noise = galsim.noise.PoissonNoise(rng) + if self.draw_method == "phot": + logger.debug("Adding poisson noise to sky photons") + sky_image1 = sky_image.copy() + sky_image1.addNoise(poisson_noise) + image.quantize() # In case any profiles used InterpolatedImage, in which case + # the image won't necessarily be integers. + image += sky_image1 + else: + logger.debug("Adding poisson noise") + image += sky_image + image.addNoise(poisson_noise) + + # Apply the detector effects here. Not all of these are "noise" per se, but they + # happen interspersed with various noise effects, so apply them all in this step. + + # Note: according to Gregory Mosby & Bernard J. Rauscher, the following effects all + # happen "simultaneously" in the photo diodes: dark current, persistence, + # reciprocity failure (aka CRNL), burn in, and nonlinearity (aka CNL). + # Right now, we just do them in some order, but this could potentially be improved. + # The order we chose is historical, matching previous recommendations, but Mosby and + # Rauscher don't seem to think those recommendations are well-motivated. + + # TODO: Add burn-in and persistence here. + + if self.reciprocity_failure: + logger.debug("Applying reciprocity failure") + roman.addReciprocityFailure(image) + + if self.dark_current: + dc = roman.dark_current * self.exptime + logger.debug("Adding dark current: %s", dc) + sky_image += dc + dark_noise = galsim.noise.DeviateNoise(galsim.random.PoissonDeviate(rng, dc)) + image.addNoise(dark_noise) + + if self.nonlinearity: + logger.debug("Applying classical nonlinearity") + roman.applyNonlinearity(image) + + # Mosby and Rauscher say there are two read noises. One happens before IPC, the other + # one after. + # TODO: Add read_noise1 + if self.ipc: + logger.debug("Applying IPC") + roman.applyIPC(image) + + if self.read_noise: + logger.debug("Adding read noise %s", roman.read_noise) + image.addNoise(galsim.GaussianNoise(rng, sigma=roman.read_noise)) + + logger.debug("Applying gain %s", roman.gain) + image /= roman.gain + + # Make integer ADU now. + image.quantize() + + if self.sky_subtract: + logger.debug("Subtracting sky image") + sky_image /= roman.gain + sky_image.quantize() + image -= sky_image + return image + + def addNoise(self, image, config, base, image_num, obj_num, current_var, logger): + pass + + +# Register this as a valid type +RegisterImageType("roman_sca_cmos", RomanSCAImageBuilderCMOS()) # Register this as a valid type RegisterImageType("roman_sca", RomanSCAImageBuilder()) diff --git a/roman_imsim/stamp.py b/roman_imsim/stamp.py index bc50d4c3..dc7a1e6d 100644 --- a/roman_imsim/stamp.py +++ b/roman_imsim/stamp.py @@ -4,10 +4,9 @@ import numpy as np from galsim.config import RegisterStampType, StampBuilder + # import os, psutil # process = psutil.Process() - - class Roman_stamp(StampBuilder): """This performs the tasks necessary for building the stamp for a single object. @@ -361,5 +360,467 @@ def add_poisson_noise(self, fft_image): Roman_stamp.fix_seds = Roman_stamp._fix_seds_25 +class Roman_stamp_CMOS(StampBuilder): + """This performs the tasks necessary for building the stamp for a single object per dt. + + It uses the regular Basic functions for most things. + It specializes the quickSkip, buildProfile, and draw methods. + """ + + _trivial_sed = galsim.SED( + galsim.LookupTable([100, 2600], [1, 1], interpolant="linear"), + wave_type="nm", + flux_type="fphotons", + ) + + def setup(self, config, base, xsize, ysize, ignore, logger): + """ + Do the initialization and setup for building a postage stamp. + + In the base class, we check for and parse the appropriate size and position values in + config (aka base['stamp'] or base['image']. + + Values given in base['stamp'] take precedence if these are given in both places (which + would be confusing, so probably shouldn't do that, but there might be a use case where it + would make sense). + + Parameters: + config: The configuration dict for the stamp field. + base: The base configuration dict. + xsize: The xsize of the image to build (if known). + ysize: The ysize of the image to build (if known). + ignore: A list of parameters that are allowed to be in config that we can + ignore here. i.e. it won't be an error if these parameters are present. + logger: A logger object to log progress. + + Returns: + xsize, ysize, image_pos, world_pos + """ + # print('stamp setup',process.memory_info().rss) + # Handle the "use_fft_bright" parameter as it can be provided in either image or stamp config + if "use_fft_bright" in base["image"] and "use_fft_bright" not in config: + config["use_fft_bright"] = base["image"]["use_fft_bright"] + # Define the exp time for use in adjusting flux/photon for any dt + # slice and exposure time based on previously defined values + + if "exptime" in config: + self.exptime = galsim.config.ParseValue(config, "exptime", base, float)[0] + else: + self.exptime = roman.exptime + + dt = galsim.config.ParseValue(config, "dt", base, float)[0] + self.dt = dt + # create cache location + if "_global" not in base: + base["_global"] = {} + if "stamp_setup_cache" not in base["_global"]: + base["_global"]["stamp_setup_cache"] = {} + + stamp_cache = base["_global"]["stamp_setup_cache"] + # check cache and build + if base["obj_num"] in stamp_cache and stamp_cache[base["obj_num"]] is not None: + cached = stamp_cache[base["obj_num"]] + self.pupil_bin = cached["pupil_bin"] + self.flux = cached["flux"] + gal = cached["gal"] + gal.flux = cached["bflux"] + base["flux"] = gal.flux + base["mag"] = cached["bmag"] + world_pos = cached["world_pos"] + image_pos = cached["image_pos"] + + # Compute or retrieve the realized flux. + self.rng = galsim.config.GetRNG(config, base, logger, "Roman_stamp") + self.realized_flux = galsim.PoissonDeviate(self.rng, mean=cached["bflux"])() + base["realized_flux"] = self.realized_flux + base["pupil_bin"] = self.pupil_bin + base["object_type"] = getattr(gal, "object_type", "") + # Check if the realized flux is 0. + if self.realized_flux == 0: + # If so, we'll skip everything after this. + # The mechanism within GalSim to do this is to raise a special SkipThisObject class. + raise galsim.config.SkipThisObject("realized flux=0") + else: + objs = base["obj_num"] + gal = galsim.config.BuildGSObject(base, "gal", logger=logger)[0] + if gal is None: + raise galsim.config.SkipThisObject("gal is None (invalid parameters)") + base["object_type"] = getattr(gal, "object_type", "") + bandpass = base["bandpass"] + if not hasattr(gal, "flux"): + # In this case, the object flux has not been precomputed + # or cached by the skyCatalogs code. + gal.flux = gal.calculateFlux(bandpass) + + self.flux = gal.flux * ( + dt / self.exptime + ) # need to check if we change flux here or elsewhere before its passed to stamps + + # Cap (star) flux at 30M photons/(dt/exptime) to avoid gross artifacts when trying + # to draw the Roman PSF in finite time and memory + flux_cap = 3e7 * (dt / self.exptime) + if self.flux > flux_cap: + if ( + hasattr(gal, "original") + and hasattr(gal.original, "original") + and isinstance(gal.original.original, galsim.DeltaFunction) + ) or (isinstance(gal, galsim.DeltaFunction)): + gal = gal.withFlux(flux_cap, bandpass) + self.flux = flux_cap + gal.flux = flux_cap + base["flux"] = ( + gal.flux + ) # need to check if we change flux here or elsewhere (this is currently the full exptime flux) + base["mag"] = -2.5 * np.log10(gal.flux) + bandpass.zeropoint + # print('stamp setup2',process.memory_info().rss) + self.bmag = base["mag"] + self.bflux = base["flux"] + # Compute or retrieve the realized flux. + + self.rng = galsim.config.GetRNG(config, base, logger, "Roman_stamp") + self.realized_flux = galsim.PoissonDeviate(self.rng, mean=gal.flux)() + base["realized_flux"] = self.realized_flux + + # Check if the realized flux is 0. + if self.realized_flux == 0: + # If so, we'll skip everything after this. + # The mechanism within GalSim to do this is to raise a special SkipThisObject class. + raise galsim.config.SkipThisObject("realized flux=0") + + # Otherwise figure out the stamp size + if self.flux < 10 * (dt / self.exptime): + # For really faint things, don't try too hard. Just use 32x32. + image_size = 32 + self.pupil_bin = "achromatic" + + else: + gal_achrom = gal.evaluateAtWavelength(bandpass.effective_wavelength) + if hasattr(gal_achrom, "original") and isinstance(gal_achrom.original, galsim.DeltaFunction): + # For bright stars, set the following stamp size limits + if self.flux < 1e6 * (dt / self.exptime): + image_size = 500 + self.pupil_bin = 8 + elif self.flux < 6e6 * (dt / self.exptime): + image_size = 1000 + self.pupil_bin = 4 + else: + image_size = 1600 + self.pupil_bin = 2 + else: + self.pupil_bin = 8 + # # Get storead achromatic PSF + # psf = galsim.config.BuildGSObject(base, 'psf', logger=logger)[0]['achromatic'] + # obj = galsim.Convolve(gal_achrom, psf).withFlux(self.flux) + obj = gal_achrom.withGSParams(galsim.GSParams(stepk_minimum_hlr=20)) + image_size = obj.getGoodImageSize(roman.pixel_scale) + self.cached_gal = gal + # print('stamp setup3',process.memory_info().rss) + base["pupil_bin"] = self.pupil_bin + logger.info("Object flux is %d", self.flux) + logger.info("Object %d will use stamp size = %s", base.get("obj_num", 0), image_size) + + # Determine where this object is going to go: + # This is the same as what the base StampBuilder does: + if "image_pos" in config: + image_pos = galsim.config.ParseValue(config, "image_pos", base, galsim.PositionD)[0] + else: + image_pos = None + + if "world_pos" in config: + world_pos = galsim.config.ParseWorldPos(config, "world_pos", base, logger) + else: + world_pos = None + stamp_cache = base["_global"]["stamp_setup_cache"] + stamp_cache[objs] = { + "gal": self.cached_gal, + "image_size": image_size, + "pupil_bin": self.pupil_bin, + "flux": self.flux, + "bflux": self.bflux, + "bmag": self.bmag, + "image_pos": image_pos, + "world_pos": world_pos, + } + return image_size, image_size, image_pos, world_pos + + def buildPSF(self, config, base, gsparams, logger): + """Build the PSF object. + + For the Basic stamp type, this builds a PSF from the base['psf'] dict, if present, + else returns None. + + Parameters: + config: The configuration dict for the stamp field. + base: The base configuration dict. + gsparams: A dict of kwargs to use for a GSParams. More may be added to this + list by the galaxy object. + logger: A logger object to log progress. + + Returns: + the PSF + """ + if base.get("psf", {}).get("type", "roman_psf") != "roman_psf": + return galsim.config.BuildGSObject(base, "psf", gsparams=gsparams, logger=logger)[0] + + roman_psf = galsim.config.GetInputObj("roman_psf", config, base, "buildPSF") + psf = roman_psf.getPSF(self.pupil_bin, base["image_pos"]) + return psf + + def getDrawMethod(self, config, base, logger): + """Determine the draw method to use. + + @param config The configuration dict for the stamp field. + @param base The base configuration dict. + @param logger A logger object to log progress. + + @returns method + """ + method = galsim.config.ParseValue(config, "draw_method", base, str)[0] + self.use_fft_bright = False + if "use_fft_bright" in config: + self.use_fft_bright = galsim.config.ParseValue(config, "use_fft_bright", base, bool)[0] + + if method not in galsim.config.valid_draw_methods: + raise galsim.GalSimConfigValueError( + "Invalid draw_method.", method, galsim.config.valid_draw_methods + ) + + if method == "phot": + if self.pupil_bin in [4, 2] and self.use_fft_bright: + logger.info("Auto -> Use FFT drawing for object %d.", base["obj_num"]) + return "fft" + else: + logger.info("Auto -> Use photon shooting for object %d.", base["obj_num"]) + return "phot" + else: + # If user sets something specific for the method, rather than auto, + # then respect their wishes. + logger.info("Use specified method=%s for object %d.", method, base["obj_num"]) + return method + + @classmethod + def _fix_seds_24(cls, prof, bandpass): + # If any SEDs are not currently using a LookupTable for the function or if they are + # using spline interpolation, then the codepath is quite slow. + # Better to fix them before doing WavelengthSampler. + if isinstance(prof, galsim.ChromaticObject): + wave_list, _, _ = galsim.utilities.combine_wave_list(prof.sed, bandpass) + sed = prof.sed + # TODO: This bit should probably be ported back to Galsim. + # Something like sed.make_tabulated() + if not isinstance(sed._spec, galsim.LookupTable) or sed._spec.interpolant != "linear": + # Workaround for https://github.com/GalSim-developers/GalSim/issues/1228 + f = np.broadcast_to(sed(wave_list), wave_list.shape) + new_spec = galsim.LookupTable(wave_list, f, interpolant="linear") + new_sed = galsim.SED(new_spec, "nm", "fphotons" if sed.spectral else "1") + prof.sed = new_sed + + # Also recurse onto any components. + if hasattr(prof, "obj_list"): + for obj in prof.obj_list: + cls._fix_seds_24(obj, bandpass) + if hasattr(prof, "original"): + cls._fix_seds_24(prof.original, bandpass) + + @classmethod + def _fix_seds_25(cls, prof, bandpass): + # If any SEDs are not currently using a LookupTable for the function or if they are + # using spline interpolation, then the codepath is quite slow. + # Better to fix them before doing WavelengthSampler. + + # In GalSim 2.5, SEDs are not necessarily constructed in most chromatic objects. + # And really the only ones we need to worry about are the ones that come from + # SkyCatalog, since they might not have linear interpolants. + # Those objects are always SimpleChromaticTransformations. So only fix those. + if isinstance(prof, galsim.SimpleChromaticTransformation) and ( + not isinstance(prof._flux_ratio._spec, galsim.LookupTable) + or prof._flux_ratio._spec.interpolant != "linear" + ): + sed = prof._flux_ratio + wave_list, _, _ = galsim.utilities.combine_wave_list(sed, bandpass) + f = np.broadcast_to(sed(wave_list), wave_list.shape) + new_spec = galsim.LookupTable(wave_list, f, interpolant="linear") + new_sed = galsim.SED(new_spec, "nm", "fphotons" if sed.spectral else "1") + prof._flux_ratio = new_sed + + # Also recurse onto any components. + if isinstance(prof, galsim.ChromaticObject): + if hasattr(prof, "obj_list"): + for obj in prof.obj_list: + cls._fix_seds_25(obj, bandpass) + if hasattr(prof, "original"): + cls._fix_seds_25(prof.original, bandpass) + + def draw(self, prof, image, method, offset, config, base, logger): + """Draw the profile on the postage stamp for fft and convert to photonArray or create photonArray + + Parameters: + prof: The profile to draw. + image: The image onto which to draw the profile (which may be None). + method: The method to use to determine how the photonArray is made. + offset: The offset to apply when drawing. + config: The configuration dict for the stamp field. + base: The base configuration dict. + logger: A logger object to log progress. + + Returns: + the resulting image + """ + if prof is None: + # If was decide to do any rejection steps, this could be set to None, in which case, + # don't draw anything. + return galsim.PhotonArray(0) + + # Prof is normally a convolution here with obj_list being [gal, psf1, psf2,...] + # for some number of component PSFs. + # print('stamp draw',process.memory_info().rss) + + gal, *psfs = prof.obj_list if hasattr(prof, "obj_list") else [prof] + faint = self.flux < 40 * (self.dt / self.exptime) + bandpass = base["bandpass"] + if faint: + logger.info("Flux = %.0f Using trivial sed", self.flux) + else: + self.fix_seds(gal, bandpass) + + image.wcs = base["wcs"] + + # Set limit on the size of photons batches to consider when + # calling gsobject.drawImage. + maxN = int(1e6 * (self.dt / self.exptime)) + if "maxN" in config: + maxN = galsim.config.ParseValue(config, "maxN", base, int)[0] + # print('stamp draw2',process.memory_info().rss) + maxN = maxN + if method == "fft": + fft_image = image.copy() + fft_offset = offset + kwargs = dict( + method=method, + offset=fft_offset, + image=fft_image, + ) + # ignore photon ops at this stage since we will convert this image to a photon array + # if not faint and config.get("fft_photon_ops"): + # kwargs.update( + # { + # "photon_ops": galsim.config.BuildPhotonOps(config, "fft_photon_ops", base, logger), + # "maxN": maxN, + # "rng": self.rng, + # "n_subsample": 1, + # } + # ) + + # Go back to a combined convolution for fft drawing. + prof = galsim.Convolve([gal] + psfs) + try: + prof.drawImage(bandpass=bandpass, **kwargs) + except galsim.errors.GalSimFFTSizeError as e: + # I think this shouldn't happen with the updates I made to how the image size + # is calculated, even for extremely bright things. So it should be ok to + # just report what happened, give some extra information to diagonose the problem + # and raise the error. + logger.error("Caught error trying to draw using FFT:") + logger.error("%s", e) + logger.error("You may need to add a gsparams field with maximum_fft_size to") + logger.error("either the psf or gal field to allow larger FFTs.") + logger.info("prof = %r", prof) + logger.info("fft_image = %s", fft_image) + logger.info("offset = %r", offset) + logger.info("wcs = %r", image.wcs) + raise + + # Check if we need to add photon noise for bright objects drawn + # with FFT because we switched from phot to fft above. + if self.use_fft_bright: + self.add_poisson_noise(fft_image) + # In case we had to make a bigger image, just copy the part we need. + image += fft_image[image.bounds] + + # Some pixels can end up negative from FFT numerics. Just set them to 0. + fft_image.array[fft_image.array < 0] = 0.0 + fft_image.addNoise( + galsim.PoissonNoise(rng=self.rng) + ) # not sure if this should be kept before converting to photon array + # In case we had to make a bigger image, just copy the part we need. + image += fft_image[image.bounds] + photons = galsim.photonArray.makeFromImage(image, max_flux=1.0, rng=None) + + photons.x += image.center.x + photons.y += image.center.y + else: + # We already calculated realized_flux above. Use that now and tell GalSim not + # recalculate the Poisson realization of the flux. + # This shouldn't be necessary with the photon approach. + # gal = gal.withFlux(self.realized_flux, bandpass) + # print('stamp draw3b ',process.memory_info().rss) + + # ignore photon ops at this stage since we will convert this image to a phont array + # if not faint and "photon_ops" in config: + # photon_ops = galsim.config.BuildPhotonOps(config, "photon_ops", base, logger) + # else: + # photon_ops = [] + + # Put the psfs at the start of the photon_ops. + # Probably a little better to put them a bit later than the start in some cases + # (e.g. after TimeSampler, PupilAnnulusSampler), but leave that as a todo for now. + # photon_ops = psfs + photon_ops + + # prof = galsim.Convolve([gal] + psfs) + + # print('-------- gal ----------',gal) + # print('-------- psf ----------',psfs) + + # print('stamp draw3a',process.memory_info().rss) + # may want to just use makePhot() in the future + + # _, photons = gal.drawPhot( + # image, + # gain=1.0, + # add_to_image=False, + # n_photons=int(galsim.PoissonDeviate(mean = self.flux)()), + # rng=self.rng, + # max_extra_noise=0.0, + # poisson_flux=None, + # sensor=None, + # photon_ops=(), + # maxN=maxN, + # orig_center=offset, + # local_wcs=None, + # surface_ops=None) + n_phots = int(galsim.PoissonDeviate(mean=self.flux)()) + if n_phots != 0: + im = gal.drawImage( + bandpass, + method="phot", + offset=offset, + rng=self.rng, + n_photons=n_phots, + image=image, + photon_ops=(), + sensor=None, + add_to_image=False, + poisson_flux=False, + save_photons=True, + ) + photons = im.photons + photons.x += im.center.x # may be a better way + photons.y += im.center.y + photons.flux = 1 + else: + photons = galsim.PhotonArray(0) + # print('stamp draw3',process.memory_info().rss) + return photons + + +# Pick the right function to be _fix_seds. +if galsim.__version_info__ < (2, 5): + Roman_stamp_CMOS.fix_seds = Roman_stamp_CMOS._fix_seds_24 +else: + Roman_stamp_CMOS.fix_seds = Roman_stamp_CMOS._fix_seds_25 + + +# Register this as a valid type +RegisterStampType("Roman_stamp_CMOS", Roman_stamp_CMOS()) # Register this as a valid type RegisterStampType("Roman_stamp", Roman_stamp())