diff --git a/roman_imsim/skycat.py b/roman_imsim/skycat.py index 4b337518..207039da 100644 --- a/roman_imsim/skycat.py +++ b/roman_imsim/skycat.py @@ -11,6 +11,11 @@ RegisterObjectType, RegisterValueType, ) +from galsim.errors import galsim_warn + + +def no_lensing(self): + return 0.0, 0.0, 1.0 class SkyCatalogInterface: @@ -33,7 +38,9 @@ def __init__( ysize=None, obj_types=None, edge_pix=100, - max_flux=None, + chromaticity=True, + skycat_lensing=True, + galsim_shear=False, logger=None, ): """ @@ -57,6 +64,16 @@ def __init__( edge_pix : float [100] Size in pixels of the buffer region around nominal image to consider objects. + chromatic : bool [True] + Whether to use the chromatic GSObjects from skyCatalogs. If False, + don't use the SED information and uses the flux inegrated over the + bandpass instead. + skycat_lensing : bool [False] + If True, then do not apply lensing to the objects from SkyCatalogs. + galsim_shear : bool [False] + Whether a shear is specified in the config file. Only here for the + purpose of raising a warning if both skycat_lensing and galsim_shear + are applied. logger : logging.Logger [None] Logger object. """ @@ -82,12 +99,24 @@ def __init__( self.sca_center = wcs.toWorld(galsim.PositionD(self.xsize / 2.0, self.ysize / 2.0)) self._objects = None - # import os, psutil - # process = psutil.Process() - # print('skycat init',process.memory_info().rss) + self.chromaticity = chromaticity + + self._skycat_lensing = skycat_lensing + if self._skycat_lensing and galsim_shear: + galsim_warn( + "A shear is applied on top of the SkyCatalog shearing. It is " + "recommended to set skycat_lensing = False when applying an " + "external shear." + ) @property def objects(self): + + if not self._skycat_lensing: + from skycatalogs.objects.diffsky_object import DiffskyObject + + DiffskyObject.get_wl_params = no_lensing + from skycatalogs import skyCatalogs from skycatalogs import __version__ as skycatalogs_version from packaging.version import Version @@ -98,9 +127,6 @@ def objects(self): from skycatalogs.utils import PolygonalRegion if self._objects is None: - # import os, psutil - # process = psutil.Process() - # print('skycat obj 1',process.memory_info().rss) # Select objects from polygonal region bounded by CCD edges corners = ( (-self.edge_pix, -self.edge_pix), @@ -117,11 +143,31 @@ def objects(self): self._objects = sky_cat.get_objects_by_region(region, obj_type_set=self.obj_types, mjd=self.mjd) if not self._objects: self.logger.warning("No objects found on image.") - # import os, psutil - # process = psutil.Process() - # print('skycat obj 2',process.memory_info().rss) + else: + self._build_dtype_dict() return self._objects + def _build_dtype_dict(self): + self._dtype_dict = {} + obj_types = [] + for coll in self._objects.get_collections(): + objects_type = coll._object_type_unique + if objects_type in obj_types: + continue + col_names = list(coll.native_columns) + for col_name in col_names: + try: + # Some columns cannot be read in snana + np_type = coll.get_native_attribute(col_name).dtype.type() + except ValueError: + self.logger.warning(f"The column {col_name} could not be read from skyCatalog.") + continue + if np_type is None: + py_type = str + else: + py_type = type(np_type.astype(object)) + self._dtype_dict[col_name] = py_type + def get_sca_center(self): """ Return the SCA center. @@ -159,7 +205,107 @@ def getWorldPos(self, index): ra, dec = skycat_obj.ra, skycat_obj.dec return galsim.CelestialCoord(ra * galsim.degrees, dec * galsim.degrees) - def getObj(self, index, gsparams=None, rng=None, exptime=30): + def getFlux(self, index=None, skycat_obj=None, filter=None, mjd=None, exptime=None): + """ + Return the flux associated to an object. + + Parameters + ---------- + index : int + Index of the object in the self.objects catalog. Either index or + skycat_obj must be provided. [Default: None] + skycat_obj : skyCatalogs object + The skyCatalogs object for which the flux is computed. Either index + or skycat_obj must be provided. [Default: None] + filter : str, optional + Name of the filter for which the flux is computed. If None, use the + filter provided during initialization. [Default: None] + mjd : float, optional + Date of the observation in MJD format. If None, use the + mjd provided during initialization. [Default: None] + exptime : int or float, optional + Exposure time of the observation. If None, use the + exptime provided during initialization. [Default: None] + + Returns + ------- + flux + Computer flux at the given date for the requested exposure time and + filter. + """ + + if filter is None: + filter = self.bandpass.name + if mjd is None: + mjd = self.mjd + if exptime is None: + exptime = self.exptime + + if index is not None and skycat_obj is None: + skycat_obj = self.objects[index] + elif skycat_obj is not None and index is None: + pass + else: + raise ValueError("Either index or skycat_obj must be provided, but not both.") + + # We cache the SEDs for potential later use + if hasattr(skycat_obj, "get_wl_params"): + # _, _, mu = skycat_obj.get_wl_params() + gamma1 = skycat_obj.get_native_attribute("shear1") + gamma2 = skycat_obj.get_native_attribute("shear2") + kappa = skycat_obj.get_native_attribute("convergence") + mu = 1.0 / ((1.0 - kappa) ** 2 - (gamma1**2 + gamma2**2)) + else: + mu = 1.0 + + self._seds = skycat_obj.get_observer_sed_components(mjd=mjd) + fluxes = {} + for cmp_name, sed in self._seds.items(): + raw_flux = sed.calculateFlux(self.bandpass) + fluxes[cmp_name] = raw_flux * mu * exptime * roman.collecting_area + + return fluxes + + def getValue(self, index, field): + """ + Return a skyCatalog value for the an object. + + Parameters + ---------- + index : int + Index of the object in the self.objects catalog. + field : str + Name of the field for which you want the value. + + Returns + ------- + int or float or str or None + The value associated to the field or None if the field do not exist. + """ + + skycat_obj = self.objects[index] + + if field not in self._dtype_dict: + # We cannot raise an error because one could have a field for snana + # in the config and we don't want to crash because there are no SN + # in this particular image. We then default to False which might not + # be the right type for the required column but we have no way of knowing + # the correct type if the column do not exist. + self.logger.warning(f"The field {field} was not found in skyCatalog.") + return None + elif field not in skycat_obj.native_columns: + if self._dtype_dict[field] is int: + # There are no "special value" for integer so we default to hopefully something completely off + # np.nan is a float and None is a string, so we use -9999 for int + return -9999 + elif self._dtype_dict[field] is float: + return np.nan + elif self._dtype_dict[field] is str: + return None + else: + return skycat_obj.get_native_attribute(field) + + def getObj(self, index, gsparams=None, rng=None): """ Return the galsim object for the skyCatalog object corresponding to the specified index. If the skyCatalog @@ -183,35 +329,41 @@ def getObj(self, index, gsparams=None, rng=None, exptime=30): gsobjs = skycat_obj.get_gsobject_components(gsparams) # Compute the flux or get the cached value. - flux = ( - skycat_obj.get_roman_flux(self.bandpass.name, mjd=self.mjd) * self.exptime * roman.collecting_area - ) + fluxes = self.getFlux(skycat_obj=skycat_obj) + flux = sum(fluxes.values()) if np.isnan(flux): return None - # if True and skycat_obj.object_type == 'galaxy': - # # Apply DC2 dilation to the individual galaxy components. - # for component, gsobj in gsobjs.items(): - # comp = component if component != 'knots' else 'disk' - # a = skycat_obj.get_native_attribute(f'size_{comp}_true') - # b = skycat_obj.get_native_attribute(f'size_minor_{comp}_true') - # scale = np.sqrt(a/b) - # gsobjs[component] = gsobj.dilate(scale) - # Set up simple SED if too faint if flux < 40: faint = True - if not faint: - seds = skycat_obj.get_observer_sed_components(mjd=self.mjd) + + # This should catch both "star" and "gaia_star" objects + if "star" in skycat_obj.object_type: + # Cap (star) flux at 30M photons to avoid gross artifacts when trying + # to draw the Roman PSF in finite time and memory + flux_cap = 3e7 + if flux > flux_cap: + flux = flux_cap + fluxes["this_object"] = flux_cap + + if self.chromaticity: + if faint: + seds = {cmp_name: self._trivial_sed for cmp_name in gsobjs} + else: + seds = skycat_obj.get_observer_sed_components(mjd=self.mjd) + else: + seds = {cmp_name: 1.0 for cmp_name in gsobjs} gs_obj_list = [] for component in gsobjs: - if faint: - gsobjs[component] = gsobjs[component].evaluateAtWavelength(self.bandpass) - gs_obj_list.append(gsobjs[component] * self._trivial_sed) + # Give the object the right flux + gsobj = gsobjs[component] * seds[component] + if self.chromaticity: + gsobj = gsobj.withFlux(fluxes[component], self.bandpass) else: - if component in seds: - gs_obj_list.append(gsobjs[component] * seds[component]) + gsobj = gsobj.withFlux(fluxes[component]) + gs_obj_list.append(gsobj) if not gs_obj_list: return None @@ -221,27 +373,18 @@ def getObj(self, index, gsparams=None, rng=None, exptime=30): else: gs_object = galsim.Add(gs_obj_list) - # This should catch both "star" and "gaia_star" objects - if "star" in skycat_obj.object_type: - # Cap (star) flux at 30M photons to avoid gross artifacts when trying - # to draw the Roman PSF in finite time and memory - flux_cap = 3e7 - if flux > flux_cap: - flux = flux_cap - - # Give the object the right flux - gs_object = gs_object.withFlux(flux, self.bandpass) - gs_object.flux = flux + # gs_object = gs_object.withFlux(flux, self.bandpass) + if not hasattr(gs_object, "flux"): + gs_object.flux = flux - # Get the object type if (skycat_obj.object_type == "diffsky_galaxy") | (skycat_obj.object_type == "galaxy"): - gs_object.object_type = "galaxy" + object_type = "galaxy" if skycat_obj.object_type in {"star", "gaia_star"}: - gs_object.object_type = "star" + object_type = "star" if skycat_obj.object_type == "snana": - gs_object.object_type = "transient" + object_type = "transient" - return gs_object + return gs_object, object_type class SkyCatalogLoader(InputLoader): @@ -257,6 +400,8 @@ def getKwargs(self, config, base, logger): "mjd": float, "xsize": int, "ysize": int, + "skycat_lensing": bool, + "chromaticity": bool, } kwargs, safe = galsim.config.GetAllParams(config, base, req=req, opt=opt) wcs = galsim.config.BuildWCS(base["image"], "wcs", base, logger=logger) @@ -267,6 +412,7 @@ def getKwargs(self, config, base, logger): base["bandpass"] = galsim.config.BuildBandpass(base["image"], "bandpass", base, logger=logger)[0] kwargs["bandpass"] = base["bandpass"] + kwargs["galsim_shear"] = "shear" in base["gal"] # Sky catalog object lists are created per CCD, so they are # not safe to reuse. safe = False @@ -289,11 +435,11 @@ def SkyCatObj(config, base, ignore, gsparams, logger): message = ( "skyCatalogs selection and SCA center do not agree: \n" "skycat.sca_center: " - f"{sca_center.ra/galsim.degrees:.5f}, " - f"{sca_center.dec/galsim.degrees:.5f}\n" + f"{sca_center.ra / galsim.degrees:.5f}, " + f"{sca_center.dec / galsim.degrees:.5f}\n" "world_center: " - f"{world_center.ra/galsim.degrees:.5f}, " - f"{world_center.dec/galsim.degrees:.5f} \n" + f"{world_center.ra / galsim.degrees:.5f}, " + f"{world_center.dec / galsim.degrees:.5f} \n" f"Separation: {sep:.2e} arcsec" ) raise RuntimeError(message) @@ -306,13 +452,14 @@ def SkyCatObj(config, base, ignore, gsparams, logger): req = {"index": int} opt = {"num": int} - kwargs, safe = galsim.config.GetAllParams(config, base, req=req, opt=opt) + kwargs, safe = galsim.config.GetAllParams(config, base, req=req, opt=opt, ignore=ignore) index = kwargs["index"] rng = galsim.config.GetRNG(config, base, logger, "SkyCatObj") - obj = skycat.getObj(index, gsparams=gsparams, rng=rng) + obj, object_type = skycat.getObj(index, gsparams=gsparams, rng=rng) base["object_id"] = skycat.objects[index].id + base["object_type"] = object_type return obj, safe @@ -336,10 +483,39 @@ def SkyCatWorldPos(config, base, value_type): return pos, safe +def SkyCatValue(config, base, value_type): + """Return a value from the object part of the skyCatalog""" + + skycat = galsim.config.GetInputObj("sky_catalog", config, base, "SkyCatValue") + + # Setup the indexing sequence if it hasn't been specified. The + # normal thing with a catalog is to just use each object in order, + # so we don't require the user to specify that by hand. We can do + # it for them. + galsim.config.SetDefaultIndex(config, skycat.getNObjects()) + + req = {"field": str, "index": int} + opt = {"obs_kind": str} + params, safe = galsim.config.GetAllParams(config, base, req=req, opt=opt) + field = params["field"] + index = params["index"] + + if field == "flux": + val = skycat.getFlux(index=index) + else: + val = skycat.getValue(index, field) + + return val, safe + + RegisterInputType("sky_catalog", SkyCatalogLoader(SkyCatalogInterface, has_nobj=True)) RegisterObjectType("SkyCatObj", SkyCatObj, input_type="sky_catalog") RegisterValueType("SkyCatWorldPos", SkyCatWorldPos, [galsim.CelestialCoord], input_type="sky_catalog") +# Here we have to provide None as a type otherwise Galsim complains but I don't know why.. +RegisterValueType("SkyCatValue", SkyCatValue, [float, int, str, None], input_type="sky_catalog") + + # This class was modified from https://github.com/LSSTDESC/imSim/. License info follows: # Copyright (c) 2016-2019, LSST Dark Energy Science Collaboration (DESC) diff --git a/roman_imsim/stamp.py b/roman_imsim/stamp.py index bc50d4c3..32b96263 100644 --- a/roman_imsim/stamp.py +++ b/roman_imsim/stamp.py @@ -53,7 +53,6 @@ def setup(self, config, base, xsize, ysize, ignore, logger): gal = galsim.config.BuildGSObject(base, "gal", logger=logger)[0] if gal is None: raise galsim.config.SkipThisObject("gal is None (invalid parameters)") - base["object_type"] = getattr(gal, "object_type", "") bandpass = base["bandpass"] if not hasattr(gal, "flux"): # In this case, the object flux has not been precomputed