From f94c804c8d4c9bf192f29a03e7fc97aa168de806 Mon Sep 17 00:00:00 2001 From: Higgins00 Date: Tue, 22 Jul 2025 11:20:02 -0400 Subject: [PATCH 1/5] Revert "removing old yaml file that has paths to data that point to file locations that are dependent on system" This reverts commit a07b01bd4685339c83628a772402b867d886a955. --- sim.yaml | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 sim.yaml diff --git a/sim.yaml b/sim.yaml new file mode 100644 index 00000000..226ae6ce --- /dev/null +++ b/sim.yaml @@ -0,0 +1,86 @@ +# Input file for wfirst.py + +# Number of procs to limit star parallelisation to for memory reasons +starproc : 20 + +# Split over nodes/procs with MPI +mpi : True + +# If overwrite is False, the job will crash if the output directories already exist to safeguard against overwriting results. +overwrite : False + +# Draw and save full SCA images. In this mode, the isolated single-galaxy postage stamps will still be saved. +draw_sca : True + +# Limit which SCAs are drawn. Can be a single number or list. None or not defined will simulate all. +sca : 1 + +# initialise random seed +random_seed : 314 + +# output directory +out_path : /fs/scratch/cond0083/wfirst_sim_out/ +# output meds (and other) filename prefix +output_meds : 'test_new' +# Minimum postage stamp size (all stamps will be multiples of this) +stamp_size : 32 +# Number of stamp sizes +num_sizes : 9 +# nside of meds tiling (healpix pixels). Choose so that full meds can be held in memory of node. +nside : 256 + +# File containing dither information +dither_file : /users/PCON0003/cond0083/observing_sequence_hlsonly_5yr.fits + +# PSF properties +# To do a more exact calculation of the chromaticity and pupil plane configuration, set the `approximate_struts` and the `n_waves` keyword to defaults +approximate_struts : True # Approximate strut configuration +n_waves : 10 # Number of wavelengths used to create chromatic model of PSF +extra_aberrations : None # Include (additive) changes specification zernike parameters. None for default. +#los_motion : 0.015 # Include extra jitter in rms of arcsec. Ignored if not defined. +#los_motion_e1 : 0.3 # Shear to apply to jitter gaussian to simulate orientation-dependent rms +#los_motion_e2 : 0.0 +# Draw stars into SCA +draw_stars : True +# Catalog containing star positions and fluxes +star_sample : /fs/scratch/cond0083/gaia_stars.fits +# Write true psf to stamps for meds +draw_true_psf : True +# Oversampling factor for true psf stamps +oversample : 8 +# Size of true psf stamps in wfirst pixel units +psf_stampsize : 8 + +# Galaxy model info +# Distribution of objects in ra, dec +gal_dist : /fs/scratch/cond0083/radec_sub.fits +# Type of galaxy model: real cosmos objects (2), models from real cosmos objects (1), sersic disk (0) - Only 0 works now +gal_type : 0 +# Photometric properties to draw from. Must provide file if gal_type == 0 +gal_sample : /users/PCON0003/cond0083/Simulated_WFIRST+LSST_photometry_catalog_CANDELSbased.fits +# Type of sersic model to build. Disk, bulge, or composite (sum of disk + bulge + star-forming knots) +gal_model : composite +# Number of random, irregular star-forming knots to include in disk of object. +knots : 25 +# List of shears to select from +shear_list : [[0.2,0.0],[-0.2,0.0],[0.0,0.2],[0.0,-0.2]] +# SEDs for galaxy components +sedpath_E : /users/PCON0003/cond0083/GalSim/share/SEDs/NGC_4926_spec.dat +sedpath_Scd : /users/PCON0003/cond0083/GalSim/share/SEDs/NGC_4670_spec.dat +sedpath_Im : /users/PCON0003/cond0083/GalSim/share/SEDs/Mrk_33_spec.dat + +# Files for cosmos galaxy models +cat_name : real_galaxy_catalog_25.2.fits +cat_dir : /fs/scratch/cond0083/COSMOS_25.2_training_sample + +# Detector options +use_background : True # Adds/subtracts sky background to images +sub_true_background : True # Currently if False, subtracts background without including impact of dark current in the subtraction. +use_poisson_noise : True # Add poisson noise to images +use_recip_failure : True # Add reciprocity failure effect +use_dark_current : True # Add dark current +use_nonlinearity : True # Apply nonlinearity +use_interpix_cap : True # Add interpixel capacitance effect +use_read_noise : True # Add read noise +use_persistence : False # Currently not implemented. +save_diff : False # Save difference images during add_effects() - will do this for every call and overwrite the previous call's files, so be prepared to kill job otherwise will be very slow. \ No newline at end of file From d7e5b40cd4107a18b085270dc18a260cdc8b8fed Mon Sep 17 00:00:00 2001 From: Higgins00 Date: Thu, 26 Jun 2025 17:01:18 -0400 Subject: [PATCH 2/5] Added some questions/TODOs and setup the file to accept the previous arguments but instead have the draw function return the photon array Added makefromimage function to fft generated galaxies for their photonarray. General concept for a way to do the sampling, but all of this will likely need to be wrapped in its own method or Buildfromstamps will need to be modified to accept photonarray as an output. Removing file I previously made to test ways to do the photonarray combination changes to sca.py to just generate an image from photon arrays Stamp now is just for the photon shooting. The sca.py has been changed to conver a dt photon array into an image. Generate an image of 1/20th an exposure using full image photon arrays. Corrected the flux scaling issue for the photons resulting in a brighter image than expected, changed photon array concatenation as well. removed egg-info and __pycache__ added a input register and value register for the resultants strategy. Implemented the resultant image builder. resultant image headers, BFE, optimizations still need to be done Some comments Separate classes so it can be merged with main. fixing formating and preparing for merge into main --- roman_imsim/__init__.py | 1 + roman_imsim/resultants.py | 85 +++++++ roman_imsim/sca.py | 352 +++++++++++++++++++++++++++++ roman_imsim/stamp.py | 454 +++++++++++++++++++++++++++++++++++++- 4 files changed, 886 insertions(+), 6 deletions(-) create mode 100644 roman_imsim/resultants.py diff --git a/roman_imsim/__init__.py b/roman_imsim/__init__.py index 72da6e7a..d3470020 100644 --- a/roman_imsim/__init__.py +++ b/roman_imsim/__init__.py @@ -16,3 +16,4 @@ from .skycat import * from .stamp import * from .wcs import * +from .resultants import * diff --git a/roman_imsim/resultants.py b/roman_imsim/resultants.py new file mode 100644 index 00000000..87318735 --- /dev/null +++ b/roman_imsim/resultants.py @@ -0,0 +1,85 @@ +import galsim +import galsim.config +import galsim.roman as roman +import yaml +from galsim.config import InputLoader, RegisterInputType, RegisterValueType + + +class ResultantDataLoader(object): + """Read the resultant information from the resultant strategy.""" + + _req_params = { + "file_name": str, + "strategy_name": str, + } + + def __init__(self, file_name, strategy_name, logger=None): + self.logger = galsim.config.LoggerWrapper(logger) + self.file_name = file_name + self.strategy_name = strategy_name + self.data = {} + + # try: + self.read_resultants() + # except: + # # Read visit info from the config file. + # self.logger.warning('Reading visit info from config file.') + + def read_resultants(self): + """Load the YAML file and get the requested strategy.""" + self.logger.info("Reading resultants from YAML file: %s", self.file_name) + try: + with open(self.file_name, "r") as f: + all_strategies = yaml.safe_load(f) + except Exception as e: + raise IOError(f"Could not read YAML file '{self.file_name}': {e}") + + if self.strategy_name not in all_strategies: + raise ValueError(f"Strategy '{self.strategy_name}' not found in YAML file.") + + strategy = all_strategies[self.strategy_name] + if not isinstance(strategy, list): + raise ValueError(f"Invalid strategy format for '{self.strategy_name}': must be a list of lists.") + + self.data["strategy"] = strategy + + def resultants_to_dt(self, config, base): + """Compute dt from list-of-lists.""" + strategy = self.data["strategy"] + if len(strategy) < 2: + raise ValueError("Need at least two resultants to compute dt.") + + avg_last = sum(strategy[-1]) / len(strategy[-1]) + avg_second = sum(strategy[0]) / len(strategy[0]) + + if "exptime" in config: + exptime = galsim.config.ParseValue(config, "exptime", base, float)[0] + else: + exptime = roman.exptime + + dt = exptime / (avg_last - avg_second) + return dt + + def get(self, field, default=None): + if field not in self.data and default is None: + raise KeyError(f"Field '{field}' not found in data.") + return self.data.get(field, default) + + +def ResultantData(config, base, value_type): + """Returns the resultant dt data.""" + rdata = galsim.config.GetInputObj("resultant_data", config, base, "ResultantDataLoader") + req = {"field": str} + kwargs, safe = galsim.config.GetAllParams(config, base, req=req) + field = kwargs["field"] + + if field == "dt": + val = rdata.resultants_to_dt(config, base) + else: + val = rdata.get(field) + + return value_type(val), safe + + +RegisterInputType("resultant_data", InputLoader(ResultantDataLoader, file_scope=True, takes_logger=True)) +RegisterValueType("ResultantData", ResultantData, [float, list], input_type="resultant_data") diff --git a/roman_imsim/sca.py b/roman_imsim/sca.py index 4d5dc15b..4a2051fe 100644 --- a/roman_imsim/sca.py +++ b/roman_imsim/sca.py @@ -1,3 +1,5 @@ +import gc + import galsim import galsim.config import galsim.roman as roman @@ -308,5 +310,355 @@ def addNoise(self, image, config, base, image_num, obj_num, current_var, logger) image -= sky_image +class RomanSCAImageBuilderCMOS(ScatteredImageBuilder): + + def setup(self, config, base, image_num, obj_num, ignore, logger): + """Do the initialization and setup for building the image. + + This figures out the size that the image will be, but doesn't actually build it yet. + + Parameters: + config: The configuration dict for the image field. + base: The base configuration dict. + image_num: The current image number. + obj_num: The first object number in the image. + ignore: A list of parameters that are allowed to be in config that we can + ignore here. i.e. it won't be an error if these parameters are present. + logger: If given, a logger object to log progress. + + Returns: + xsize, ysize + """ + # import os, psutil + # process = psutil.Process() + # print('sca setup 1',process.memory_info().rss) + logger.debug( + "image %d: Building RomanSCA: image, obj = %d,%d", + image_num, + image_num, + obj_num, + ) + + self.nobjects = self.getNObj(config, base, image_num, logger=logger) + logger.debug("image %d: nobj = %d", image_num, self.nobjects) + + # These are allowed for Scattered, but we don't use them here. + extra_ignore = [ + "image_pos", + "world_pos", + "stamp_size", + "stamp_xsize", + "stamp_ysize", + "nobjects", + ] + req = {"SCA": int, "filter": str, "mjd": float, "exptime": float} + opt = { + "draw_method": str, + "use_fft_bright": bool, + "stray_light": bool, + "thermal_background": bool, + "reciprocity_failure": bool, + "dark_current": bool, + "nonlinearity": bool, + "ipc": bool, + "read_noise": bool, + "sky_subtract": bool, + "ignore_noise": bool, + } + params = galsim.config.GetAllParams(config, base, req=req, opt=opt, ignore=ignore + extra_ignore)[0] + + self.sca = params["SCA"] + base["SCA"] = self.sca + self.filter = params["filter"] + self.mjd = params["mjd"] + self.exptime = params["exptime"] + + self.ignore_noise = params.get("ignore_noise", False) + # self.exptime = params.get('exptime', roman.exptime) # Default is roman standard exposure time. + self.stray_light = params.get("stray_light", False) + self.thermal_background = params.get("thermal_background", False) + self.reciprocity_failure = params.get("reciprocity_failure", False) + self.dark_current = params.get("dark_current", False) + self.nonlinearity = params.get("nonlinearity", False) + self.ipc = params.get("ipc", False) + self.read_noise = params.get("read_noise", False) + self.sky_subtract = params.get("sky_subtract", False) + + # If draw_method isn't in image field, it may be in stamp. Check. + self.draw_method = params.get("draw_method", base.get("stamp", {}).get("draw_method", "phot")) + + # pointing = CelestialCoord(ra=params['ra'], dec=params['dec']) + # wcs = roman.getWCS(world_pos = pointing, + # PA = params['pa']*galsim.degrees, + # date = params['date'], + # SCAs = self.sca, + # PA_is_FPA = True + # )[self.sca] + + # # GalSim expects a wcs in the image field. + # config['wcs'] = wcs + + # If user hasn't overridden the bandpass to use, get the standard one. + if "bandpass" not in config: + base["bandpass"] = galsim.config.BuildBandpass(base["image"], "bandpass", base, logger=logger) + + return roman.n_pix, roman.n_pix + + # def getBandpass(self, filter_name): + # if not hasattr(self, 'all_roman_bp'): + # self.all_roman_bp = roman.getBandpasses() + # return self.all_roman_bp[filter_name] + + def buildImage(self, config, base, image_num, obj_num, logger): + """Build an Image containing multiple objects placed at arbitrary locations. + + Parameters: + config: The configuration dict for the image field. + base: The base configuration dict. + image_num: The current image number. + obj_num: The first object number in the image. + logger: If given, a logger object to log progress. + + Returns: + the final image and the current noise variance in the image as a tuple + """ + full_xsize = base["image_xsize"] + full_ysize = base["image_ysize"] + wcs = base["wcs"] + rdata = galsim.config.GetInputObj("resultant_data", config, base, "ResultantDataLoader") + strategy = rdata.get("strategy") + full_image = Image(full_xsize, full_ysize, dtype=float) + full_image.setOrigin(base["image_origin"]) + full_image.wcs = wcs + full_image.setZero() + + full_image.header = galsim.FitsHeader() + full_image.header["EXPTIME"] = self.exptime + full_image.header["MJD-OBS"] = self.mjd + full_image.header["DATE-OBS"] = Time(self.mjd, format="mjd").datetime.isoformat() + full_image.header["FILTER"] = self.filter + full_image.header["ZPTMAG"] = 2.5 * np.log10(self.exptime * roman.collecting_area) + + base["current_image"] = full_image + + if "image_pos" in config and "world_pos" in config: + raise galsim.GalSimConfigValueError( + "Both image_pos and world_pos specified for Scattered image.", + (config["image_pos"], config["world_pos"]), + ) + + if "image_pos" not in config and "world_pos" not in config: + xmin = base["image_origin"].x + xmax = xmin + full_xsize - 1 + ymin = base["image_origin"].y + ymax = ymin + full_ysize - 1 + config["image_pos"] = { + "type": "XY", + "x": {"type": "Random", "min": xmin, "max": xmax}, + "y": {"type": "Random", "min": ymin, "max": ymax}, + } + # Create index and lists for resultant management + max_dt = strategy[-1][-1] + resultant_i = 0 + resultant_buffer = [] + # Iterate through all dt + for dt in np.arange(1, max_dt + 1): + + nbatch = self.nobjects // 1000 + 1 + full_array = galsim.PhotonArray(0) + for batch in range(nbatch): + start_obj_num = self.nobjects * batch // nbatch + end_obj_num = self.nobjects * (batch + 1) // nbatch + nobj_batch = end_obj_num - start_obj_num + if nbatch > 1: + logger.warning( + "Start batch %d/%d with %d objects [%d, %d)", + batch + 1, + nbatch, + nobj_batch, + start_obj_num, + end_obj_num, + ) + stamps, current_vars = galsim.config.BuildStamps( + nobj_batch, base, logger=logger, obj_num=start_obj_num, do_noise=False + ) + base["index_key"] = "image_num" + + for k in range(nobj_batch): + # This is our signal that the object was skipped. + if stamps[k] is None: + continue + bounds = full_image.bounds # stamps[k].bounds & + if not bounds.isDefined(): # pragma: no cover + # These noramlly show up as stamp==None, but technically it is possible + # to get a stamp that is off the main image, so check for that here to + # avoid an error. But this isn't covered in the imsim test suite. + continue + + # logger.debug("image %d: full bounds = %s", image_num, str(full_image.bounds)) + # logger.debug( + # "image %d: stamp %d bounds = %s", + # image_num, + # k + start_obj_num, + # str(stamps[k].bounds), + # ) + # logger.debug("image %d: Overlap = %s", image_num, str(bounds)) + # full_image[bounds] += stamps[k][bounds] + # logger.warning(stamps[k]) + full_array = galsim.PhotonArray.concatenate([*stamps, full_array]) + + stamps = None + + # # Bring the image so far up to a flat noise variance + # current_var = FlattenNoiseVariance( + # base, full_image, stamps, current_vars, logger) + # TODO : Apply BFE photon operation (uses current pre-read image and next photon array) + + # Turn full_image into running pre-read image + full_array.addTo(full_image) + del full_array + gc.collect() + # Decide what to do with readout based on resultant strategy + if ( + np.array([item for sub in strategy for item in sub]) == dt + ).any(): # does this dt exist in strategy + if (np.array(strategy[resultant_i]) == dt).any(): # is it in our current resultant + readout_im = full_image.copy() + readout_im = self.addNoiseToImage(readout_im, config, base, logger) + resultant_buffer.extend([readout_im]) + if len(resultant_buffer) > 1: + # combine readout images + resultant_buffer[0].array = resultant_buffer[0].array + resultant_buffer[1].array + del resultant_buffer[-1] + gc.collect() + + if np.array(strategy[resultant_i][-1]) == dt: # is this dt the last in our current resultant + divisor = len(np.array(strategy[resultant_i])) + # divide summed images by the length of resultant to get the average + # apply headers to the image array + # TODO:apply header to the image array + resultant_buffer[0].array = resultant_buffer[0].array / divisor + resultant_buffer[0].write("resultant_{0}.fits".format(resultant_i)) + resultant_i += 1 + resultant_buffer = [] + logger.warning("resultant{0} done".format(resultant_i)) + + # full_array.write("photonarray.fits") + # full_image.write("phot_image.fits") + + return full_image, None + + def addNoiseToImage(self, image, config, base, logger): + """Add the final noise to a Scattered image + + Parameters: + image: The image onto which to add the noise. + config: The configuration dict for the image field. + base: The base configuration dict. + image_num: The current image number. + obj_num: The first object number in the image. + current_var: The current noise variance in each postage stamps. + logger: If given, a logger object to log progress. + """ + # check ignore noise + if self.ignore_noise: + return image + + base["current_noise_image"] = base["current_image"] + wcs = base["wcs"] + bp = base["bandpass"] + rng = galsim.config.GetRNG(config, base) + logger.info("image %d: Start RomanSCA detector effects", base.get("image_num", 0)) + + # Things that will eventually be subtracted (if sky_subtract) will have their expectation + # value added to sky_image. So technically, this includes things that aren't just sky. + # E.g. includes dark_current and thermal backgrounds. + sky_image = image.copy() + sky_level = roman.getSkyLevel(bp, world_pos=wcs.toWorld(image.true_center)) + logger.debug("Adding sky_level = %s", sky_level) + if self.stray_light: + logger.debug("Stray light fraction = %s", roman.stray_light_fraction) + sky_level *= 1.0 + roman.stray_light_fraction + wcs.makeSkyImage(sky_image, sky_level) + + # The other background is the expected thermal backgrounds in this band. + # These are provided in e-/pix/s, so we have to multiply by the exposure time. + if self.thermal_background: + tb = roman.thermal_backgrounds[self.filter] * self.exptime + logger.debug("Adding thermal background: %s", tb) + sky_image += roman.thermal_backgrounds[self.filter] * self.exptime + + # The image up to here is an expectation value. + # Realize it as an integer number of photons. + poisson_noise = galsim.noise.PoissonNoise(rng) + if self.draw_method == "phot": + logger.debug("Adding poisson noise to sky photons") + sky_image1 = sky_image.copy() + sky_image1.addNoise(poisson_noise) + image.quantize() # In case any profiles used InterpolatedImage, in which case + # the image won't necessarily be integers. + image += sky_image1 + else: + logger.debug("Adding poisson noise") + image += sky_image + image.addNoise(poisson_noise) + + # Apply the detector effects here. Not all of these are "noise" per se, but they + # happen interspersed with various noise effects, so apply them all in this step. + + # Note: according to Gregory Mosby & Bernard J. Rauscher, the following effects all + # happen "simultaneously" in the photo diodes: dark current, persistence, + # reciprocity failure (aka CRNL), burn in, and nonlinearity (aka CNL). + # Right now, we just do them in some order, but this could potentially be improved. + # The order we chose is historical, matching previous recommendations, but Mosby and + # Rauscher don't seem to think those recommendations are well-motivated. + + # TODO: Add burn-in and persistence here. + + if self.reciprocity_failure: + logger.debug("Applying reciprocity failure") + roman.addReciprocityFailure(image) + + if self.dark_current: + dc = roman.dark_current * self.exptime + logger.debug("Adding dark current: %s", dc) + sky_image += dc + dark_noise = galsim.noise.DeviateNoise(galsim.random.PoissonDeviate(rng, dc)) + image.addNoise(dark_noise) + + if self.nonlinearity: + logger.debug("Applying classical nonlinearity") + roman.applyNonlinearity(image) + + # Mosby and Rauscher say there are two read noises. One happens before IPC, the other + # one after. + # TODO: Add read_noise1 + if self.ipc: + logger.debug("Applying IPC") + roman.applyIPC(image) + + if self.read_noise: + logger.debug("Adding read noise %s", roman.read_noise) + image.addNoise(galsim.GaussianNoise(rng, sigma=roman.read_noise)) + + logger.debug("Applying gain %s", roman.gain) + image /= roman.gain + + # Make integer ADU now. + image.quantize() + + if self.sky_subtract: + logger.debug("Subtracting sky image") + sky_image /= roman.gain + sky_image.quantize() + image -= sky_image + return image + + def addNoise(self, image, config, base, image_num, obj_num, current_var, logger): + pass + + +# Register this as a valid type +RegisterImageType("roman_sca_cmos", RomanSCAImageBuilderCMOS()) # Register this as a valid type RegisterImageType("roman_sca", RomanSCAImageBuilder()) diff --git a/roman_imsim/stamp.py b/roman_imsim/stamp.py index bc50d4c3..af718ec0 100644 --- a/roman_imsim/stamp.py +++ b/roman_imsim/stamp.py @@ -4,10 +4,9 @@ import numpy as np from galsim.config import RegisterStampType, StampBuilder + # import os, psutil # process = psutil.Process() - - class Roman_stamp(StampBuilder): """This performs the tasks necessary for building the stamp for a single object. @@ -60,6 +59,30 @@ def setup(self, config, base, xsize, ysize, ignore, logger): # or cached by the skyCatalogs code. gal.flux = gal.calculateFlux(bandpass) self.flux = gal.flux + # Cap (star) flux at 30M photons to avoid gross artifacts when trying + # to draw the Roman PSF in finite time and memory + flux_cap = 3e7 + if self.flux > flux_cap: + if ( + hasattr(gal, "original") + and hasattr(gal.original, "original") + and isinstance(gal.original.original, galsim.DeltaFunction) + ) or (isinstance(gal, galsim.DeltaFunction)): + gal = gal.withFlux(flux_cap, bandpass) + self.flux = flux_cap + gal.flux = flux_cap + # Cap (star) flux at 30M photons to avoid gross artifacts when trying + # to draw the Roman PSF in finite time and memory + flux_cap = 3e7 + if self.flux > flux_cap: + if ( + hasattr(gal, "original") + and hasattr(gal.original, "original") + and isinstance(gal.original.original, galsim.DeltaFunction) + ) or (isinstance(gal, galsim.DeltaFunction)): + gal = gal.withFlux(flux_cap, bandpass) + self.flux = flux_cap + gal.flux = flux_cap base["flux"] = gal.flux base["mag"] = -2.5 * np.log10(gal.flux) + bandpass.zeropoint # print('stamp setup2',process.memory_info().rss) @@ -182,8 +205,8 @@ def _fix_seds_24(cls, prof, bandpass): # using spline interpolation, then the codepath is quite slow. # Better to fix them before doing WavelengthSampler. if isinstance(prof, galsim.ChromaticObject): - wave_list, _, _ = galsim.utilities.combine_wave_list(prof.sed, bandpass) - sed = prof.sed + wave_list, _, _ = galsim.utilities.combine_wave_list(prof.SED, bandpass) + sed = prof.SED # TODO: This bit should probably be ported back to Galsim. # Something like sed.make_tabulated() if not isinstance(sed._spec, galsim.LookupTable) or sed._spec.interpolant != "linear": @@ -191,7 +214,7 @@ def _fix_seds_24(cls, prof, bandpass): f = np.broadcast_to(sed(wave_list), wave_list.shape) new_spec = galsim.LookupTable(wave_list, f, interpolant="linear") new_sed = galsim.SED(new_spec, "nm", "fphotons" if sed.spectral else "1") - prof.sed = new_sed + prof.SED = new_sed # Also recurse onto any components. if hasattr(prof, "obj_list"): @@ -322,9 +345,10 @@ def draw(self, prof, image, method, offset, config, base, logger): ) # Go back to a combined convolution for fft drawing. + gal = gal.withFlux(self.flux, bandpass) prof = galsim.Convolve([gal] + psfs) try: - prof.drawImage(bandpass=bandpass, **kwargs) + prof.drawImage(bandpass, **kwargs) except galsim.errors.GalSimFFTSizeError as e: # I think this shouldn't happen with the updates I made to how the image size # is calculated, even for extremely bright things. So it should be ok to @@ -361,5 +385,423 @@ def add_poisson_noise(self, fft_image): Roman_stamp.fix_seds = Roman_stamp._fix_seds_25 +class Roman_stamp_CMOS(StampBuilder): + """This performs the tasks necessary for building the stamp for a single object per dt. + + It uses the regular Basic functions for most things. + It specializes the quickSkip, buildProfile, and draw methods. + """ + + _trivial_sed = galsim.SED( + galsim.LookupTable([100, 2600], [1, 1], interpolant="linear"), + wave_type="nm", + flux_type="fphotons", + ) + + def setup(self, config, base, xsize, ysize, ignore, logger): + """ + Do the initialization and setup for building a postage stamp. + + In the base class, we check for and parse the appropriate size and position values in + config (aka base['stamp'] or base['image']. + + Values given in base['stamp'] take precedence if these are given in both places (which + would be confusing, so probably shouldn't do that, but there might be a use case where it + would make sense). + + Parameters: + config: The configuration dict for the stamp field. + base: The base configuration dict. + xsize: The xsize of the image to build (if known). + ysize: The ysize of the image to build (if known). + ignore: A list of parameters that are allowed to be in config that we can + ignore here. i.e. it won't be an error if these parameters are present. + logger: A logger object to log progress. + + Returns: + xsize, ysize, image_pos, world_pos + """ + # print('stamp setup',process.memory_info().rss) + # Handle the "use_fft_bright" parameter as it can be provided in either image or stamp config + if "use_fft_bright" in base["image"] and "use_fft_bright" not in config: + config["use_fft_bright"] = base["image"]["use_fft_bright"] + # Define the exp time for use in adjusting flux/photon for any dt + # slice and exposure time based on previously defined values + + if "exptime" in config: + self.exptime = galsim.config.ParseValue(config, "exptime", base, float)[0] + else: + self.exptime = roman.exptime + + dt = galsim.config.ParseValue(config, "dt", base, float)[0] + gal = galsim.config.BuildGSObject(base, "gal", logger=logger)[0] + if gal is None: + raise galsim.config.SkipThisObject("gal is None (invalid parameters)") + base["object_type"] = getattr(gal, "object_type", "") + bandpass = base["bandpass"] + if not hasattr(gal, "flux"): + # In this case, the object flux has not been precomputed + # or cached by the skyCatalogs code. + gal.flux = gal.calculateFlux(bandpass) + + self.flux = gal.flux * ( + dt / self.exptime + ) # need to check if we change flux here or elsewhere before its passed to stamps + + # Cap (star) flux at 30M photons/(dt/exptime) to avoid gross artifacts when trying + # to draw the Roman PSF in finite time and memory + flux_cap = 3e7 * (dt / self.exptime) + if self.flux > flux_cap: + if ( + hasattr(gal, "original") + and hasattr(gal.original, "original") + and isinstance(gal.original.original, galsim.DeltaFunction) + ) or (isinstance(gal, galsim.DeltaFunction)): + gal = gal.withFlux(flux_cap, bandpass) + self.flux = flux_cap + gal.flux = flux_cap + base["flux"] = ( + gal.flux + ) # need to check if we change flux here or elsewhere (this is currently the full exptime flux) + base["mag"] = -2.5 * np.log10(gal.flux) + bandpass.zeropoint + # print('stamp setup2',process.memory_info().rss) + + # Compute or retrieve the realized flux. + self.rng = galsim.config.GetRNG(config, base, logger, "Roman_stamp") + self.realized_flux = galsim.PoissonDeviate(self.rng, mean=gal.flux)() + base["realized_flux"] = self.realized_flux + + # Check if the realized flux is 0. + if self.realized_flux == 0: + # If so, we'll skip everything after this. + # The mechanism within GalSim to do this is to raise a special SkipThisObject class. + raise galsim.config.SkipThisObject("realized flux=0") + + # Otherwise figure out the stamp size + if self.flux < 10 * (dt / self.exptime): + # For really faint things, don't try too hard. Just use 32x32. + image_size = 32 + self.pupil_bin = "achromatic" + + else: + gal_achrom = gal.evaluateAtWavelength(bandpass.effective_wavelength) + if hasattr(gal_achrom, "original") and isinstance(gal_achrom.original, galsim.DeltaFunction): + # For bright stars, set the following stamp size limits + if self.flux < 1e6 * (dt / self.exptime): + image_size = 500 + self.pupil_bin = 8 + elif self.flux < 6e6 * (dt / self.exptime): + image_size = 1000 + self.pupil_bin = 4 + else: + image_size = 1600 + self.pupil_bin = 2 + else: + self.pupil_bin = 8 + # # Get storead achromatic PSF + # psf = galsim.config.BuildGSObject(base, 'psf', logger=logger)[0]['achromatic'] + # obj = galsim.Convolve(gal_achrom, psf).withFlux(self.flux) + obj = gal_achrom.withGSParams(galsim.GSParams(stepk_minimum_hlr=20)) + image_size = obj.getGoodImageSize(roman.pixel_scale) + + # print('stamp setup3',process.memory_info().rss) + base["pupil_bin"] = self.pupil_bin + logger.info("Object flux is %d", self.flux) + logger.info("Object %d will use stamp size = %s", base.get("obj_num", 0), image_size) + + # Determine where this object is going to go: + # This is the same as what the base StampBuilder does: + if "image_pos" in config: + image_pos = galsim.config.ParseValue(config, "image_pos", base, galsim.PositionD)[0] + else: + image_pos = None + + if "world_pos" in config: + world_pos = galsim.config.ParseWorldPos(config, "world_pos", base, logger) + else: + world_pos = None + + return image_size, image_size, image_pos, world_pos + + def buildPSF(self, config, base, gsparams, logger): + """Build the PSF object. + + For the Basic stamp type, this builds a PSF from the base['psf'] dict, if present, + else returns None. + + Parameters: + config: The configuration dict for the stamp field. + base: The base configuration dict. + gsparams: A dict of kwargs to use for a GSParams. More may be added to this + list by the galaxy object. + logger: A logger object to log progress. + + Returns: + the PSF + """ + if base.get("psf", {}).get("type", "roman_psf") != "roman_psf": + return galsim.config.BuildGSObject(base, "psf", gsparams=gsparams, logger=logger)[0] + + roman_psf = galsim.config.GetInputObj("roman_psf", config, base, "buildPSF") + psf = roman_psf.getPSF(self.pupil_bin, base["image_pos"]) + return psf + + def getDrawMethod(self, config, base, logger): + """Determine the draw method to use. + + @param config The configuration dict for the stamp field. + @param base The base configuration dict. + @param logger A logger object to log progress. + + @returns method + """ + method = galsim.config.ParseValue(config, "draw_method", base, str)[0] + self.use_fft_bright = False + if "use_fft_bright" in config: + self.use_fft_bright = galsim.config.ParseValue(config, "use_fft_bright", base, bool)[0] + + if method not in galsim.config.valid_draw_methods: + raise galsim.GalSimConfigValueError( + "Invalid draw_method.", method, galsim.config.valid_draw_methods + ) + + if method == "phot": + if self.pupil_bin in [4, 2] and self.use_fft_bright: + logger.info("Auto -> Use FFT drawing for object %d.", base["obj_num"]) + return "fft" + else: + logger.info("Auto -> Use photon shooting for object %d.", base["obj_num"]) + return "phot" + else: + # If user sets something specific for the method, rather than auto, + # then respect their wishes. + logger.info("Use specified method=%s for object %d.", method, base["obj_num"]) + return method + + @classmethod + def _fix_seds_24(cls, prof, bandpass): + # If any SEDs are not currently using a LookupTable for the function or if they are + # using spline interpolation, then the codepath is quite slow. + # Better to fix them before doing WavelengthSampler. + if isinstance(prof, galsim.ChromaticObject): + wave_list, _, _ = galsim.utilities.combine_wave_list(prof.sed, bandpass) + sed = prof.sed + # TODO: This bit should probably be ported back to Galsim. + # Something like sed.make_tabulated() + if not isinstance(sed._spec, galsim.LookupTable) or sed._spec.interpolant != "linear": + # Workaround for https://github.com/GalSim-developers/GalSim/issues/1228 + f = np.broadcast_to(sed(wave_list), wave_list.shape) + new_spec = galsim.LookupTable(wave_list, f, interpolant="linear") + new_sed = galsim.SED(new_spec, "nm", "fphotons" if sed.spectral else "1") + prof.sed = new_sed + + # Also recurse onto any components. + if hasattr(prof, "obj_list"): + for obj in prof.obj_list: + cls._fix_seds_24(obj, bandpass) + if hasattr(prof, "original"): + cls._fix_seds_24(prof.original, bandpass) + + @classmethod + def _fix_seds_25(cls, prof, bandpass): + # If any SEDs are not currently using a LookupTable for the function or if they are + # using spline interpolation, then the codepath is quite slow. + # Better to fix them before doing WavelengthSampler. + + # In GalSim 2.5, SEDs are not necessarily constructed in most chromatic objects. + # And really the only ones we need to worry about are the ones that come from + # SkyCatalog, since they might not have linear interpolants. + # Those objects are always SimpleChromaticTransformations. So only fix those. + if isinstance(prof, galsim.SimpleChromaticTransformation) and ( + not isinstance(prof._flux_ratio._spec, galsim.LookupTable) + or prof._flux_ratio._spec.interpolant != "linear" + ): + sed = prof._flux_ratio + wave_list, _, _ = galsim.utilities.combine_wave_list(sed, bandpass) + f = np.broadcast_to(sed(wave_list), wave_list.shape) + new_spec = galsim.LookupTable(wave_list, f, interpolant="linear") + new_sed = galsim.SED(new_spec, "nm", "fphotons" if sed.spectral else "1") + prof._flux_ratio = new_sed + + # Also recurse onto any components. + if isinstance(prof, galsim.ChromaticObject): + if hasattr(prof, "obj_list"): + for obj in prof.obj_list: + cls._fix_seds_25(obj, bandpass) + if hasattr(prof, "original"): + cls._fix_seds_25(prof.original, bandpass) + + def draw(self, prof, image, method, offset, config, base, logger): + """Draw the profile on the postage stamp for fft and convert to photonArray or create photonArray + + Parameters: + prof: The profile to draw. + image: The image onto which to draw the profile (which may be None). + method: The method to use to determine how the photonArray is made. + offset: The offset to apply when drawing. + config: The configuration dict for the stamp field. + base: The base configuration dict. + logger: A logger object to log progress. + + Returns: + the resulting image + """ + if prof is None: + # If was decide to do any rejection steps, this could be set to None, in which case, + # don't draw anything. + return galsim.PhotonArray(0) + + # Prof is normally a convolution here with obj_list being [gal, psf1, psf2,...] + # for some number of component PSFs. + # print('stamp draw',process.memory_info().rss) + + gal, *psfs = prof.obj_list if hasattr(prof, "obj_list") else [prof] + dt = self.exptime / 20 + faint = self.flux < 40 * (dt / self.exptime) + bandpass = base["bandpass"] + if faint: + logger.info("Flux = %.0f Using trivial sed", self.flux) + else: + self.fix_seds(gal, bandpass) + + image.wcs = base["wcs"] + + # Set limit on the size of photons batches to consider when + # calling gsobject.drawImage. + maxN = int(1e6 * (dt / self.exptime)) + if "maxN" in config: + maxN = galsim.config.ParseValue(config, "maxN", base, int)[0] + # print('stamp draw2',process.memory_info().rss) + maxN = maxN + if method == "fft": + fft_image = image.copy() + fft_offset = offset + kwargs = dict( + method=method, + offset=fft_offset, + image=fft_image, + ) + # ignore photon ops at this stage since we will convert this image to a photon array + # if not faint and config.get("fft_photon_ops"): + # kwargs.update( + # { + # "photon_ops": galsim.config.BuildPhotonOps(config, "fft_photon_ops", base, logger), + # "maxN": maxN, + # "rng": self.rng, + # "n_subsample": 1, + # } + # ) + + # Go back to a combined convolution for fft drawing. + prof = galsim.Convolve([gal] + psfs) + try: + prof.drawImage(bandpass=bandpass, **kwargs) + except galsim.errors.GalSimFFTSizeError as e: + # I think this shouldn't happen with the updates I made to how the image size + # is calculated, even for extremely bright things. So it should be ok to + # just report what happened, give some extra information to diagonose the problem + # and raise the error. + logger.error("Caught error trying to draw using FFT:") + logger.error("%s", e) + logger.error("You may need to add a gsparams field with maximum_fft_size to") + logger.error("either the psf or gal field to allow larger FFTs.") + logger.info("prof = %r", prof) + logger.info("fft_image = %s", fft_image) + logger.info("offset = %r", offset) + logger.info("wcs = %r", image.wcs) + raise + + # Check if we need to add photon noise for bright objects drawn + # with FFT because we switched from phot to fft above. + if self.use_fft_bright: + self.add_poisson_noise(fft_image) + # In case we had to make a bigger image, just copy the part we need. + image += fft_image[image.bounds] + + # Some pixels can end up negative from FFT numerics. Just set them to 0. + fft_image.array[fft_image.array < 0] = 0.0 + fft_image.addNoise( + galsim.PoissonNoise(rng=self.rng) + ) # not sure if this should be kept before converting to photon array + # In case we had to make a bigger image, just copy the part we need. + image += fft_image[image.bounds] + photons = galsim.photonArray.makeFromImage(image, max_flux=1.0, rng=None) + + photons.x += image.center.x + photons.y += image.center.y + else: + # We already calculated realized_flux above. Use that now and tell GalSim not + # recalculate the Poisson realization of the flux. + # This shouldn't be necessary with the photon approach. + # gal = gal.withFlux(self.realized_flux, bandpass) + # print('stamp draw3b ',process.memory_info().rss) + + # ignore photon ops at this stage since we will convert this image to a phont array + # if not faint and "photon_ops" in config: + # photon_ops = galsim.config.BuildPhotonOps(config, "photon_ops", base, logger) + # else: + # photon_ops = [] + + # Put the psfs at the start of the photon_ops. + # Probably a little better to put them a bit later than the start in some cases + # (e.g. after TimeSampler, PupilAnnulusSampler), but leave that as a todo for now. + # photon_ops = psfs + photon_ops + + # prof = galsim.Convolve([gal] + psfs) + + # print('-------- gal ----------',gal) + # print('-------- psf ----------',psfs) + + # print('stamp draw3a',process.memory_info().rss) + # may want to just use makePhot() in the future + + # _, photons = gal.drawPhot( + # image, + # gain=1.0, + # add_to_image=False, + # n_photons=int(galsim.PoissonDeviate(mean = self.flux)()), + # rng=self.rng, + # max_extra_noise=0.0, + # poisson_flux=None, + # sensor=None, + # photon_ops=(), + # maxN=maxN, + # orig_center=offset, + # local_wcs=None, + # surface_ops=None) + n_phots = int(galsim.PoissonDeviate(mean=self.flux)()) + if n_phots != 0: + im = gal.drawImage( + bandpass, + method="phot", + offset=offset, + rng=self.rng, + n_photons=n_phots, + image=image, + photon_ops=(), + sensor=None, + add_to_image=False, + poisson_flux=False, + save_photons=True, + ) + photons = im.photons + photons.x += im.center.x # may be a better way + photons.y += im.center.y + photons.flux = 1 + else: + photons = galsim.PhotonArray(0) + # print('stamp draw3',process.memory_info().rss) + return photons + + +# Pick the right function to be _fix_seds. +if galsim.__version_info__ < (2, 5): + Roman_stamp_CMOS.fix_seds = Roman_stamp_CMOS._fix_seds_24 +else: + Roman_stamp_CMOS.fix_seds = Roman_stamp_CMOS._fix_seds_25 + + +# Register this as a valid type +RegisterStampType("Roman_stamp_CMOS", Roman_stamp_CMOS()) # Register this as a valid type RegisterStampType("Roman_stamp", Roman_stamp()) From ff71bfa6e5dc1eca242339e1b5866a51a2832355 Mon Sep 17 00:00:00 2001 From: Higgins00 Date: Tue, 17 Feb 2026 13:28:28 -0500 Subject: [PATCH 3/5] making sure I dont change the current stamp.py --- roman_imsim/stamp.py | 33 ++++----------------------------- 1 file changed, 4 insertions(+), 29 deletions(-) diff --git a/roman_imsim/stamp.py b/roman_imsim/stamp.py index af718ec0..5cc50719 100644 --- a/roman_imsim/stamp.py +++ b/roman_imsim/stamp.py @@ -59,30 +59,6 @@ def setup(self, config, base, xsize, ysize, ignore, logger): # or cached by the skyCatalogs code. gal.flux = gal.calculateFlux(bandpass) self.flux = gal.flux - # Cap (star) flux at 30M photons to avoid gross artifacts when trying - # to draw the Roman PSF in finite time and memory - flux_cap = 3e7 - if self.flux > flux_cap: - if ( - hasattr(gal, "original") - and hasattr(gal.original, "original") - and isinstance(gal.original.original, galsim.DeltaFunction) - ) or (isinstance(gal, galsim.DeltaFunction)): - gal = gal.withFlux(flux_cap, bandpass) - self.flux = flux_cap - gal.flux = flux_cap - # Cap (star) flux at 30M photons to avoid gross artifacts when trying - # to draw the Roman PSF in finite time and memory - flux_cap = 3e7 - if self.flux > flux_cap: - if ( - hasattr(gal, "original") - and hasattr(gal.original, "original") - and isinstance(gal.original.original, galsim.DeltaFunction) - ) or (isinstance(gal, galsim.DeltaFunction)): - gal = gal.withFlux(flux_cap, bandpass) - self.flux = flux_cap - gal.flux = flux_cap base["flux"] = gal.flux base["mag"] = -2.5 * np.log10(gal.flux) + bandpass.zeropoint # print('stamp setup2',process.memory_info().rss) @@ -205,8 +181,8 @@ def _fix_seds_24(cls, prof, bandpass): # using spline interpolation, then the codepath is quite slow. # Better to fix them before doing WavelengthSampler. if isinstance(prof, galsim.ChromaticObject): - wave_list, _, _ = galsim.utilities.combine_wave_list(prof.SED, bandpass) - sed = prof.SED + wave_list, _, _ = galsim.utilities.combine_wave_list(prof.sed, bandpass) + sed = prof.sed # TODO: This bit should probably be ported back to Galsim. # Something like sed.make_tabulated() if not isinstance(sed._spec, galsim.LookupTable) or sed._spec.interpolant != "linear": @@ -214,7 +190,7 @@ def _fix_seds_24(cls, prof, bandpass): f = np.broadcast_to(sed(wave_list), wave_list.shape) new_spec = galsim.LookupTable(wave_list, f, interpolant="linear") new_sed = galsim.SED(new_spec, "nm", "fphotons" if sed.spectral else "1") - prof.SED = new_sed + prof.sed = new_sed # Also recurse onto any components. if hasattr(prof, "obj_list"): @@ -345,10 +321,9 @@ def draw(self, prof, image, method, offset, config, base, logger): ) # Go back to a combined convolution for fft drawing. - gal = gal.withFlux(self.flux, bandpass) prof = galsim.Convolve([gal] + psfs) try: - prof.drawImage(bandpass, **kwargs) + prof.drawImage(bandpass=bandpass, **kwargs) except galsim.errors.GalSimFFTSizeError as e: # I think this shouldn't happen with the updates I made to how the image size # is calculated, even for extremely bright things. So it should be ok to From ca087617d2fa7b985a3ac8f21085f5baf4ea1032 Mon Sep 17 00:00:00 2001 From: Higgins00 Date: Tue, 17 Feb 2026 13:52:37 -0500 Subject: [PATCH 4/5] removed sim.yaml --- sim.yaml | 86 -------------------------------------------------------- 1 file changed, 86 deletions(-) delete mode 100644 sim.yaml diff --git a/sim.yaml b/sim.yaml deleted file mode 100644 index 226ae6ce..00000000 --- a/sim.yaml +++ /dev/null @@ -1,86 +0,0 @@ -# Input file for wfirst.py - -# Number of procs to limit star parallelisation to for memory reasons -starproc : 20 - -# Split over nodes/procs with MPI -mpi : True - -# If overwrite is False, the job will crash if the output directories already exist to safeguard against overwriting results. -overwrite : False - -# Draw and save full SCA images. In this mode, the isolated single-galaxy postage stamps will still be saved. -draw_sca : True - -# Limit which SCAs are drawn. Can be a single number or list. None or not defined will simulate all. -sca : 1 - -# initialise random seed -random_seed : 314 - -# output directory -out_path : /fs/scratch/cond0083/wfirst_sim_out/ -# output meds (and other) filename prefix -output_meds : 'test_new' -# Minimum postage stamp size (all stamps will be multiples of this) -stamp_size : 32 -# Number of stamp sizes -num_sizes : 9 -# nside of meds tiling (healpix pixels). Choose so that full meds can be held in memory of node. -nside : 256 - -# File containing dither information -dither_file : /users/PCON0003/cond0083/observing_sequence_hlsonly_5yr.fits - -# PSF properties -# To do a more exact calculation of the chromaticity and pupil plane configuration, set the `approximate_struts` and the `n_waves` keyword to defaults -approximate_struts : True # Approximate strut configuration -n_waves : 10 # Number of wavelengths used to create chromatic model of PSF -extra_aberrations : None # Include (additive) changes specification zernike parameters. None for default. -#los_motion : 0.015 # Include extra jitter in rms of arcsec. Ignored if not defined. -#los_motion_e1 : 0.3 # Shear to apply to jitter gaussian to simulate orientation-dependent rms -#los_motion_e2 : 0.0 -# Draw stars into SCA -draw_stars : True -# Catalog containing star positions and fluxes -star_sample : /fs/scratch/cond0083/gaia_stars.fits -# Write true psf to stamps for meds -draw_true_psf : True -# Oversampling factor for true psf stamps -oversample : 8 -# Size of true psf stamps in wfirst pixel units -psf_stampsize : 8 - -# Galaxy model info -# Distribution of objects in ra, dec -gal_dist : /fs/scratch/cond0083/radec_sub.fits -# Type of galaxy model: real cosmos objects (2), models from real cosmos objects (1), sersic disk (0) - Only 0 works now -gal_type : 0 -# Photometric properties to draw from. Must provide file if gal_type == 0 -gal_sample : /users/PCON0003/cond0083/Simulated_WFIRST+LSST_photometry_catalog_CANDELSbased.fits -# Type of sersic model to build. Disk, bulge, or composite (sum of disk + bulge + star-forming knots) -gal_model : composite -# Number of random, irregular star-forming knots to include in disk of object. -knots : 25 -# List of shears to select from -shear_list : [[0.2,0.0],[-0.2,0.0],[0.0,0.2],[0.0,-0.2]] -# SEDs for galaxy components -sedpath_E : /users/PCON0003/cond0083/GalSim/share/SEDs/NGC_4926_spec.dat -sedpath_Scd : /users/PCON0003/cond0083/GalSim/share/SEDs/NGC_4670_spec.dat -sedpath_Im : /users/PCON0003/cond0083/GalSim/share/SEDs/Mrk_33_spec.dat - -# Files for cosmos galaxy models -cat_name : real_galaxy_catalog_25.2.fits -cat_dir : /fs/scratch/cond0083/COSMOS_25.2_training_sample - -# Detector options -use_background : True # Adds/subtracts sky background to images -sub_true_background : True # Currently if False, subtracts background without including impact of dark current in the subtraction. -use_poisson_noise : True # Add poisson noise to images -use_recip_failure : True # Add reciprocity failure effect -use_dark_current : True # Add dark current -use_nonlinearity : True # Apply nonlinearity -use_interpix_cap : True # Add interpixel capacitance effect -use_read_noise : True # Add read noise -use_persistence : False # Currently not implemented. -save_diff : False # Save difference images during add_effects() - will do this for every call and overwrite the previous call's files, so be prepared to kill job otherwise will be very slow. \ No newline at end of file From 1db3eb32bb2b51f96941586bd8b8199989ddd39d Mon Sep 17 00:00:00 2001 From: Higgins00 Date: Thu, 19 Feb 2026 03:17:01 -0500 Subject: [PATCH 5/5] optimized for time. Roughly same run time as non photonarray way. 2GB memory required --- roman_imsim/sca.py | 8 +- roman_imsim/stamp.py | 216 ++++++++++++++++++++++++++----------------- 2 files changed, 137 insertions(+), 87 deletions(-) diff --git a/roman_imsim/sca.py b/roman_imsim/sca.py index 4a2051fe..745ebdb8 100644 --- a/roman_imsim/sca.py +++ b/roman_imsim/sca.py @@ -1,5 +1,4 @@ import gc - import galsim import galsim.config import galsim.roman as roman @@ -462,10 +461,16 @@ def buildImage(self, config, base, image_num, obj_num, logger): resultant_i = 0 resultant_buffer = [] # Iterate through all dt + if "_global" not in base: + base["_global"] = {} + if "stamp_setup_cache" not in base["_global"]: + base["_global"]["stamp_setup_cache"] = {} + for dt in np.arange(1, max_dt + 1): nbatch = self.nobjects // 1000 + 1 full_array = galsim.PhotonArray(0) + for batch in range(nbatch): start_obj_num = self.nobjects * batch // nbatch end_obj_num = self.nobjects * (batch + 1) // nbatch @@ -482,6 +487,7 @@ def buildImage(self, config, base, image_num, obj_num, logger): stamps, current_vars = galsim.config.BuildStamps( nobj_batch, base, logger=logger, obj_num=start_obj_num, do_noise=False ) + logger.warning(base["_global"]["stamp_setup_cache"][batch]) base["index_key"] = "image_num" for k in range(nobj_batch): diff --git a/roman_imsim/stamp.py b/roman_imsim/stamp.py index 5cc50719..dc7a1e6d 100644 --- a/roman_imsim/stamp.py +++ b/roman_imsim/stamp.py @@ -409,93 +409,138 @@ def setup(self, config, base, xsize, ysize, ignore, logger): self.exptime = roman.exptime dt = galsim.config.ParseValue(config, "dt", base, float)[0] - gal = galsim.config.BuildGSObject(base, "gal", logger=logger)[0] - if gal is None: - raise galsim.config.SkipThisObject("gal is None (invalid parameters)") - base["object_type"] = getattr(gal, "object_type", "") - bandpass = base["bandpass"] - if not hasattr(gal, "flux"): - # In this case, the object flux has not been precomputed - # or cached by the skyCatalogs code. - gal.flux = gal.calculateFlux(bandpass) - - self.flux = gal.flux * ( - dt / self.exptime - ) # need to check if we change flux here or elsewhere before its passed to stamps - - # Cap (star) flux at 30M photons/(dt/exptime) to avoid gross artifacts when trying - # to draw the Roman PSF in finite time and memory - flux_cap = 3e7 * (dt / self.exptime) - if self.flux > flux_cap: - if ( - hasattr(gal, "original") - and hasattr(gal.original, "original") - and isinstance(gal.original.original, galsim.DeltaFunction) - ) or (isinstance(gal, galsim.DeltaFunction)): - gal = gal.withFlux(flux_cap, bandpass) - self.flux = flux_cap - gal.flux = flux_cap - base["flux"] = ( - gal.flux - ) # need to check if we change flux here or elsewhere (this is currently the full exptime flux) - base["mag"] = -2.5 * np.log10(gal.flux) + bandpass.zeropoint - # print('stamp setup2',process.memory_info().rss) - - # Compute or retrieve the realized flux. - self.rng = galsim.config.GetRNG(config, base, logger, "Roman_stamp") - self.realized_flux = galsim.PoissonDeviate(self.rng, mean=gal.flux)() - base["realized_flux"] = self.realized_flux - - # Check if the realized flux is 0. - if self.realized_flux == 0: - # If so, we'll skip everything after this. - # The mechanism within GalSim to do this is to raise a special SkipThisObject class. - raise galsim.config.SkipThisObject("realized flux=0") - - # Otherwise figure out the stamp size - if self.flux < 10 * (dt / self.exptime): - # For really faint things, don't try too hard. Just use 32x32. - image_size = 32 - self.pupil_bin = "achromatic" - + self.dt = dt + # create cache location + if "_global" not in base: + base["_global"] = {} + if "stamp_setup_cache" not in base["_global"]: + base["_global"]["stamp_setup_cache"] = {} + + stamp_cache = base["_global"]["stamp_setup_cache"] + # check cache and build + if base["obj_num"] in stamp_cache and stamp_cache[base["obj_num"]] is not None: + cached = stamp_cache[base["obj_num"]] + self.pupil_bin = cached["pupil_bin"] + self.flux = cached["flux"] + gal = cached["gal"] + gal.flux = cached["bflux"] + base["flux"] = gal.flux + base["mag"] = cached["bmag"] + world_pos = cached["world_pos"] + image_pos = cached["image_pos"] + + # Compute or retrieve the realized flux. + self.rng = galsim.config.GetRNG(config, base, logger, "Roman_stamp") + self.realized_flux = galsim.PoissonDeviate(self.rng, mean=cached["bflux"])() + base["realized_flux"] = self.realized_flux + base["pupil_bin"] = self.pupil_bin + base["object_type"] = getattr(gal, "object_type", "") + # Check if the realized flux is 0. + if self.realized_flux == 0: + # If so, we'll skip everything after this. + # The mechanism within GalSim to do this is to raise a special SkipThisObject class. + raise galsim.config.SkipThisObject("realized flux=0") else: - gal_achrom = gal.evaluateAtWavelength(bandpass.effective_wavelength) - if hasattr(gal_achrom, "original") and isinstance(gal_achrom.original, galsim.DeltaFunction): - # For bright stars, set the following stamp size limits - if self.flux < 1e6 * (dt / self.exptime): - image_size = 500 - self.pupil_bin = 8 - elif self.flux < 6e6 * (dt / self.exptime): - image_size = 1000 - self.pupil_bin = 4 + objs = base["obj_num"] + gal = galsim.config.BuildGSObject(base, "gal", logger=logger)[0] + if gal is None: + raise galsim.config.SkipThisObject("gal is None (invalid parameters)") + base["object_type"] = getattr(gal, "object_type", "") + bandpass = base["bandpass"] + if not hasattr(gal, "flux"): + # In this case, the object flux has not been precomputed + # or cached by the skyCatalogs code. + gal.flux = gal.calculateFlux(bandpass) + + self.flux = gal.flux * ( + dt / self.exptime + ) # need to check if we change flux here or elsewhere before its passed to stamps + + # Cap (star) flux at 30M photons/(dt/exptime) to avoid gross artifacts when trying + # to draw the Roman PSF in finite time and memory + flux_cap = 3e7 * (dt / self.exptime) + if self.flux > flux_cap: + if ( + hasattr(gal, "original") + and hasattr(gal.original, "original") + and isinstance(gal.original.original, galsim.DeltaFunction) + ) or (isinstance(gal, galsim.DeltaFunction)): + gal = gal.withFlux(flux_cap, bandpass) + self.flux = flux_cap + gal.flux = flux_cap + base["flux"] = ( + gal.flux + ) # need to check if we change flux here or elsewhere (this is currently the full exptime flux) + base["mag"] = -2.5 * np.log10(gal.flux) + bandpass.zeropoint + # print('stamp setup2',process.memory_info().rss) + self.bmag = base["mag"] + self.bflux = base["flux"] + # Compute or retrieve the realized flux. + + self.rng = galsim.config.GetRNG(config, base, logger, "Roman_stamp") + self.realized_flux = galsim.PoissonDeviate(self.rng, mean=gal.flux)() + base["realized_flux"] = self.realized_flux + + # Check if the realized flux is 0. + if self.realized_flux == 0: + # If so, we'll skip everything after this. + # The mechanism within GalSim to do this is to raise a special SkipThisObject class. + raise galsim.config.SkipThisObject("realized flux=0") + + # Otherwise figure out the stamp size + if self.flux < 10 * (dt / self.exptime): + # For really faint things, don't try too hard. Just use 32x32. + image_size = 32 + self.pupil_bin = "achromatic" + + else: + gal_achrom = gal.evaluateAtWavelength(bandpass.effective_wavelength) + if hasattr(gal_achrom, "original") and isinstance(gal_achrom.original, galsim.DeltaFunction): + # For bright stars, set the following stamp size limits + if self.flux < 1e6 * (dt / self.exptime): + image_size = 500 + self.pupil_bin = 8 + elif self.flux < 6e6 * (dt / self.exptime): + image_size = 1000 + self.pupil_bin = 4 + else: + image_size = 1600 + self.pupil_bin = 2 else: - image_size = 1600 - self.pupil_bin = 2 + self.pupil_bin = 8 + # # Get storead achromatic PSF + # psf = galsim.config.BuildGSObject(base, 'psf', logger=logger)[0]['achromatic'] + # obj = galsim.Convolve(gal_achrom, psf).withFlux(self.flux) + obj = gal_achrom.withGSParams(galsim.GSParams(stepk_minimum_hlr=20)) + image_size = obj.getGoodImageSize(roman.pixel_scale) + self.cached_gal = gal + # print('stamp setup3',process.memory_info().rss) + base["pupil_bin"] = self.pupil_bin + logger.info("Object flux is %d", self.flux) + logger.info("Object %d will use stamp size = %s", base.get("obj_num", 0), image_size) + + # Determine where this object is going to go: + # This is the same as what the base StampBuilder does: + if "image_pos" in config: + image_pos = galsim.config.ParseValue(config, "image_pos", base, galsim.PositionD)[0] else: - self.pupil_bin = 8 - # # Get storead achromatic PSF - # psf = galsim.config.BuildGSObject(base, 'psf', logger=logger)[0]['achromatic'] - # obj = galsim.Convolve(gal_achrom, psf).withFlux(self.flux) - obj = gal_achrom.withGSParams(galsim.GSParams(stepk_minimum_hlr=20)) - image_size = obj.getGoodImageSize(roman.pixel_scale) - - # print('stamp setup3',process.memory_info().rss) - base["pupil_bin"] = self.pupil_bin - logger.info("Object flux is %d", self.flux) - logger.info("Object %d will use stamp size = %s", base.get("obj_num", 0), image_size) - - # Determine where this object is going to go: - # This is the same as what the base StampBuilder does: - if "image_pos" in config: - image_pos = galsim.config.ParseValue(config, "image_pos", base, galsim.PositionD)[0] - else: - image_pos = None - - if "world_pos" in config: - world_pos = galsim.config.ParseWorldPos(config, "world_pos", base, logger) - else: - world_pos = None + image_pos = None + if "world_pos" in config: + world_pos = galsim.config.ParseWorldPos(config, "world_pos", base, logger) + else: + world_pos = None + stamp_cache = base["_global"]["stamp_setup_cache"] + stamp_cache[objs] = { + "gal": self.cached_gal, + "image_size": image_size, + "pupil_bin": self.pupil_bin, + "flux": self.flux, + "bflux": self.bflux, + "bmag": self.bmag, + "image_pos": image_pos, + "world_pos": world_pos, + } return image_size, image_size, image_pos, world_pos def buildPSF(self, config, base, gsparams, logger): @@ -631,8 +676,7 @@ def draw(self, prof, image, method, offset, config, base, logger): # print('stamp draw',process.memory_info().rss) gal, *psfs = prof.obj_list if hasattr(prof, "obj_list") else [prof] - dt = self.exptime / 20 - faint = self.flux < 40 * (dt / self.exptime) + faint = self.flux < 40 * (self.dt / self.exptime) bandpass = base["bandpass"] if faint: logger.info("Flux = %.0f Using trivial sed", self.flux) @@ -643,7 +687,7 @@ def draw(self, prof, image, method, offset, config, base, logger): # Set limit on the size of photons batches to consider when # calling gsobject.drawImage. - maxN = int(1e6 * (dt / self.exptime)) + maxN = int(1e6 * (self.dt / self.exptime)) if "maxN" in config: maxN = galsim.config.ParseValue(config, "maxN", base, int)[0] # print('stamp draw2',process.memory_info().rss)