Skip to content

Commit bd68e87

Browse files
committed
Clean up `__add__` to use a single dict to carry all the per-axis info
1 parent 2ebde11 commit bd68e87

1 file changed

Lines changed: 118 additions & 104 deletions

File tree

src/pyuvdata/uvdata/uvdata.py

Lines changed: 118 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -5502,61 +5502,85 @@ def __add__(
55025502
# Define parameters that must be the same to add objects
55035503
compatibility_params = ["_vis_units"]
55045504

5505-
axes = ["Nblts", "Nfreqs", "Npols"]
5506-
# axis_key_params defines which parameters to use as the defining
5507-
# parameters along each axis. These are used to identify overlapping data.
5508-
axis_key_params = {
5509-
"Nblts": ["time_array", "baseline_array"],
5510-
"Nfreqs": ["freq_array", "flex_spw_id_array"],
5511-
"Npols": ["polarization_array"],
5505+
# setup a dict to carry all the axis-specific info we need throughout
5506+
# the add process:
5507+
# - description is used in history string
5508+
# - key_params defines which parameters to use as the defining
5509+
# parameters along each axis. These are used to identify overlapping data.
5510+
# - key_func specifies a function to form a combined string if there are
5511+
# multiple key arrays (e.g. baseline-time, spw-freq)
5512+
# - reorder gives method & parameters for reordering along each axis
5513+
# - order has info about how to sort each axis. Initialize to None
5514+
# for axes that are not added along (so do not need sorting),
5515+
# updated later.
5516+
# ---added later---
5517+
# - check_params gives parameters that should be checked if adding
5518+
# along other axes
5519+
# - key_arrays gives the arrays to use for checking for overlap per axis
5520+
# - overlap_inds has the outcomes of np.intersect1d on the key_arrays
5521+
# between this and other: the common values (both) and this/other inds
5522+
# for any overlaps.
5523+
# - other_inds_use has the indices in other that will be added to this
5524+
# - combined_key_arrays has the final key arrays after adding.
5525+
# - t2o has the mapping of where arrays on other get mapped into this
5526+
# along each axis after padding
5527+
5528+
axis_info = {
5529+
"Nblts": {
5530+
"description": "baseline-time",
5531+
"key_params": ["time_array", "baseline_array"],
5532+
"key_func": "blt_str_arr",
5533+
"reorder": {"method": "reorder_blts", "parameter": "order"},
5534+
"order": None,
5535+
},
5536+
"Nfreqs": {
5537+
"description": "frequency",
5538+
"key_params": ["freq_array", "flex_spw_id_array"],
5539+
"key_func": "spw_freq_str_arr",
5540+
"reorder": {"method": "reorder_freqs", "parameter": "channel_order"},
5541+
"order": None,
5542+
},
5543+
"Npols": {
5544+
"description": "polarization",
5545+
"key_params": ["polarization_array"],
5546+
"reorder": {"method": "reorder_pols", "parameter": "order"},
5547+
"order": None,
5548+
},
55125549
}
5513-
# specify a function to form a combined string if there are multiple
5514-
# key arrays (e.g. baseline-time, spw-freq)
5515-
axis_key_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"}
5516-
multi_axis_params = this._get_multi_axis_params()
5517-
# axis_parameters gives parameters whose form contains each axis
5518-
axis_parameters = {}
5519-
# axis_check_params gives parameters that should be checked if adding
5520-
# along other axes
5521-
axis_check_params = {}
5522-
# axis_key_arrays gives the arrays to use for checking for overlap per axis
5523-
axis_key_arrays = {}
5524-
# axis_overlap_inds has the outcomes of np.intersect1d on the
5525-
# axis_key_arrays per axis. So it has the both/this inds/other inds
5526-
# for any overlaps.
5527-
axis_overlap_inds = {}
5528-
for axis, overlap_params in axis_key_params.items():
5529-
axis_parameters[axis] = this._get_param_axis(axis)
5530-
axis_check_params[axis] = []
5531-
for param in axis_parameters[axis]:
5532-
# get parameters for compatibility checking. Exclude parameters
5533-
# that define overlap and multidimensional parameters which are
5534-
# handled separately later.
5535-
if (
5536-
param not in multi_axis_params
5537-
and param not in axis_key_params[axis]
5538-
):
5539-
axis_check_params[axis].append("_" + param)
5550+
5551+
for axis, info in axis_info.items():
5552+
# get parameters for compatibility checking. Exclude multidimensional
5553+
# parameters which are handled separately later.
5554+
params_this_axis = this._get_param_axis(axis, single_named_axis=True)
5555+
info["check_params"] = []
5556+
for param in params_this_axis:
5557+
# Also exclude parameters that define overlap
5558+
if param not in info["key_params"]:
5559+
info["check_params"].append("_" + param)
55405560

55415561
# build this/other arrays for checking for overlap.
5542-
if len(overlap_params) > 1:
5543-
axis_key_arrays[axis] = {
5544-
"this": getattr(this, axis_key_func[axis])(),
5545-
"other": getattr(other, axis_key_func[axis])(),
5562+
# key_arrays gives the arrays to use for checking for overlap per axis
5563+
if len(info["key_params"]) > 1:
5564+
info["key_arrays"] = {
5565+
"this": getattr(this, info["key_func"])(),
5566+
"other": getattr(other, info["key_func"])(),
55465567
}
55475568
else:
5548-
axis_key_arrays[axis] = {
5549-
"this": getattr(this, overlap_params[0]),
5550-
"other": getattr(other, overlap_params[0]),
5569+
info["key_arrays"] = {
5570+
"this": getattr(this, info["key_params"][0]),
5571+
"other": getattr(other, info["key_params"][0]),
55515572
}
55525573

55535574
# Check if we have overlapping data
55545575
both_inds, this_inds, other_inds = np.intersect1d(
5555-
axis_key_arrays[axis]["this"],
5556-
axis_key_arrays[axis]["other"],
5576+
info["key_arrays"]["this"],
5577+
info["key_arrays"]["other"],
55575578
return_indices=True,
55585579
)
5559-
axis_overlap_inds[axis] = {
5580+
# overlap_inds has the outcomes of np.intersect1d on the
5581+
# key_arrays per axis: the common values (both) and this/other inds
5582+
# for any overlaps.
5583+
info["overlap_inds"] = {
55605584
"this": this_inds,
55615585
"other": other_inds,
55625586
"both": both_inds,
@@ -5565,11 +5589,12 @@ def __add__(
55655589
history_update_string = ""
55665590

55675591
if np.all(
5568-
[len(axis_overlap_inds[axis]["both"]) > 0 for axis in axis_overlap_inds]
5592+
[len(axis_info[axis]["overlap_inds"]["both"]) > 0 for axis in axis_info]
55695593
):
55705594
# We have overlaps, check that overlapping data is not valid
55715595
this_test = []
55725596
other_test = []
5597+
multi_axis_params = this._get_multi_axis_params()
55735598
for param in multi_axis_params:
55745599
form = getattr(this, "_" + param).form
55755600
this_shape = getattr(this, param).shape
@@ -5583,12 +5608,12 @@ def __add__(
55835608
expand_axes = [ax for ax in range(len(form)) if ax != ax_ind]
55845609
this_index_list.append(
55855610
np.expand_dims(
5586-
axis_overlap_inds[axis]["this"], axis=expand_axes
5611+
axis_info[axis]["overlap_inds"]["this"], axis=expand_axes
55875612
)
55885613
)
55895614
other_index_list.append(
55905615
np.expand_dims(
5591-
axis_overlap_inds[axis]["other"], axis=expand_axes
5616+
axis_info[axis]["overlap_inds"]["other"], axis=expand_axes
55925617
)
55935618
)
55945619
this_inds = np.ravel_multi_index(this_index_list, this_shape).flatten()
@@ -5622,41 +5647,35 @@ def __add__(
56225647
)
56235648

56245649
# Now actually find which axes are going to be added along
5625-
# other_inds_use have the indices in other that will be added to this
5626-
other_inds_use = {}
56275650
additions = []
5628-
axis_descriptions = {
5629-
"Nblts": "baseline-time",
5630-
"Nfreqs": "frequency",
5631-
"Npols": "polarization",
5632-
}
56335651
# find the indices in "other" but not in "this"
5634-
for axis in axes:
5652+
for axis, info in axis_info.items():
56355653
temp = np.nonzero(
5636-
~np.isin(axis_key_arrays[axis]["other"], axis_key_arrays[axis]["this"])
5654+
~np.isin(info["key_arrays"]["other"], info["key_arrays"]["this"])
56375655
)[0]
56385656
if len(temp) > 0:
5639-
other_inds_use[axis] = temp
5657+
# other_inds_use has the indices in other that will be added to this
5658+
info["other_inds_use"] = temp
56405659
# add params associated with the other axes to compatibility_params
5641-
for axis2 in axes:
5660+
for axis2 in axis_info:
56425661
if axis2 != axis:
5643-
compatibility_params.extend(axis_check_params[axis2])
5644-
additions.append(axis_descriptions[axis])
5662+
compatibility_params.extend(axis_info[axis2]["check_params"])
5663+
additions.append(info["description"])
56455664
else:
5646-
other_inds_use[axis] = []
5665+
info["other_inds_use"] = []
56475666

56485667
# Actually check compatibility parameters
56495668
for cp in compatibility_params:
56505669
params_match = None
5651-
for axis, check_list in axis_check_params.items():
5652-
if cp in check_list:
5670+
for axis, info in axis_info.items():
5671+
if cp in info["check_params"]:
56535672
# only check that overlapping indices match
56545673
this_param = getattr(this, cp)
56555674
this_param_overlap = this_param.get_from_form(
5656-
{axis: axis_overlap_inds[axis]["this"]}
5675+
{axis: info["overlap_inds"]["this"]}
56575676
)
56585677
other_param_overlap = getattr(other, cp).get_from_form(
5659-
{axis: axis_overlap_inds[axis]["other"]}
5678+
{axis: info["overlap_inds"]["other"]}
56605679
)
56615680
params_match = np.allclose(
56625681
this_param_overlap,
@@ -5683,85 +5702,80 @@ def __add__(
56835702

56845703
# Next, we want to make sure that the ordering of the _overlapping_ data is
56855704
# the same, so that things can get plugged together in a sensible way.
5686-
reorder_method = {
5687-
"Nblts": {"method": "reorder_blts", "parameter": "order"},
5688-
"Nfreqs": {"method": "reorder_freqs", "parameter": "channel_order"},
5689-
"Npols": {"method": "reorder_pols", "parameter": "order"},
5690-
}
5691-
for axis, ind_dict in axis_overlap_inds.items():
5692-
if len(ind_dict["this"]) != 0:
5705+
for axis, info in axis_info.items():
5706+
if len(info["overlap_inds"]["this"]) != 0:
56935707
# there is some overlap, so check sorting
5694-
this_argsort = np.argsort(ind_dict["this"])
5695-
other_argsort = np.argsort(ind_dict["other"])
5708+
this_argsort = np.argsort(info["overlap_inds"]["this"])
5709+
other_argsort = np.argsort(info["overlap_inds"]["other"])
56965710

56975711
if np.any(this_argsort != other_argsort):
56985712
temp_ind = np.arange(getattr(this, axis))
5699-
temp_ind[ind_dict["this"][this_argsort]] = temp_ind[
5700-
ind_dict["this"][other_argsort]
5713+
temp_ind[info["overlap_inds"]["this"][this_argsort]] = temp_ind[
5714+
info["overlap_inds"]["this"][other_argsort]
57015715
]
5702-
kwargs = {reorder_method[axis]["parameter"]: temp_ind}
5716+
kwargs = {info["reorder"]["parameter"]: temp_ind}
57035717

5704-
getattr(this, reorder_method[axis]["method"])(**kwargs)
5718+
getattr(this, info["reorder"]["method"])(**kwargs)
57055719

57065720
# checks are all done, start updating parameters
5707-
# combined_key_arrays has the final key arrays after adding.
5708-
combined_key_arrays = {}
5709-
# order_dict has info about how to sort each axis. Initialize to None
5710-
# for axes that are not added along (so do not need sorting)
5711-
order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None}
5712-
for axis in axes:
5713-
if len(other_inds_use[axis]) > 0:
5714-
combined_key_arrays[axis] = np.concatenate(
5721+
for axis, info in axis_info.items():
5722+
if len(info["other_inds_use"]) > 0:
5723+
# combined_key_arrays has the final key arrays after adding.
5724+
info["combined_key_arrays"] = np.concatenate(
57155725
(
5716-
axis_key_arrays[axis]["this"],
5717-
axis_key_arrays[axis]["other"][other_inds_use[axis]],
5726+
info["key_arrays"]["this"],
5727+
info["key_arrays"]["other"][info["other_inds_use"]],
57185728
)
57195729
)
57205730
if axis == "Npols":
5721-
order_dict[axis] = np.argsort(np.abs(combined_key_arrays[axis]))
5731+
# order has info about how to sort each axis.
5732+
info["order"] = np.argsort(np.abs(info["combined_key_arrays"]))
57225733
elif axis == "Nfreqs" and (
57235734
np.any(np.diff(this.freq_array) < 0)
57245735
or np.any(np.diff(other.freq_array) < 0)
57255736
):
57265737
# deal with the possibility of spws with channels in
57275738
# descending order.
5728-
order_dict[axis] = utils.frequency._add_freq_order(
5739+
info["order"] = utils.frequency._add_freq_order(
57295740
np.concatenate(
57305741
(
57315742
this.flex_spw_id_array,
5732-
other.flex_spw_id_array[other_inds_use[axis]],
5743+
other.flex_spw_id_array[info["other_inds_use"]],
57335744
)
57345745
),
57355746
np.concatenate(
5736-
(this.freq_array, other.freq_array[other_inds_use[axis]])
5747+
(this.freq_array, other.freq_array[info["other_inds_use"]])
57375748
),
57385749
)
57395750
else:
5740-
order_dict[axis] = np.argsort(combined_key_arrays[axis])
5751+
info["order"] = np.argsort(info["combined_key_arrays"])
57415752

57425753
# first handle parameters with a single named axis
57435754
this._axis_add_helper(
5744-
other, axis, other_inds_use[axis], order_dict[axis]
5755+
other, axis, info["other_inds_use"], info["order"]
57455756
)
57465757

57475758
# then pad out parameters with multiple axes
5748-
this._axis_pad_helper(axis, len(other_inds_use[axis]))
5759+
this._axis_pad_helper(axis, len(info["other_inds_use"]))
57495760
else:
57505761
# no add along this axis, so it's the same as what's already on this
5751-
combined_key_arrays[axis] = axis_key_arrays[axis]["this"]
5762+
info["combined_key_arrays"] = info["key_arrays"]["this"]
57525763

57535764
# Now fill in multidimensional parameters
5754-
# t2o_dict has the mapping of where arrays on other get mapped into
5765+
# t2o has the mapping of where arrays on other get mapped into
57555766
# this after padding
5756-
t2o_dict = {}
5757-
for axis, inds_dict in axis_key_arrays.items():
5758-
t2o_dict[axis] = np.nonzero(
5759-
np.isin(combined_key_arrays[axis], inds_dict["other"])
5767+
for _, info in axis_info.items():
5768+
info["t2o"] = np.nonzero(
5769+
np.isin(info["combined_key_arrays"], info["key_arrays"]["other"])
57605770
)[0]
57615771

5762-
this._fill_multi_helper(other, t2o_dict, order_dict)
5772+
this._fill_multi_helper(
5773+
other,
5774+
{axis: info["t2o"] for axis, info in axis_info.items()},
5775+
{axis: info["order"] for axis, info in axis_info.items()},
5776+
)
57635777

5764-
if len(other_inds_use["Nfreqs"]) > 0:
5778+
if len(axis_info["Nfreqs"]["other_inds_use"]) > 0:
57655779
# We want to preserve per-spw information based on first appearance
57665780
# in the concatenated array.
57675781
unique_index = np.sort(
@@ -5815,7 +5829,7 @@ def __add__(
58155829
)
58165830

58175831
# Reset blt_order if blt axis was added to
5818-
if order_dict["Nblts"] is not None:
5832+
if axis_info["Nblts"]["order"] is not None:
58195833
this.blt_order = ("time", "baseline")
58205834

58215835
this.set_rectangularity(force=True)

0 commit comments

Comments
 (0)