Skip to content

Commit 44f8692

Browse files
committed
Better comments and variable names
1 parent 32b4b1b commit 44f8692

2 files changed

Lines changed: 99 additions & 77 deletions

File tree

src/pyuvdata/uvbase.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False):
776776
ret_dict = {}
777777
for param in self:
778778
# For each attribute, if the value is None, then bail, otherwise
779-
# attempt to figure out along which axis ind_arr will apply.
779+
# find the axis number(s) with the named shape.
780780

781781
attr = getattr(self, param)
782782
if (
@@ -791,8 +791,8 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False):
791791
continue
792792

793793
# Only look at where form is a tuple, since that's the only case we
794-
# can have a dynamically defined shape. Note that index doesn't work
795-
# here in the case of a repeated param_name in the form.
794+
# can have a dynamically defined shape. Handle a repeated
795+
# param_name in the form.
796796
ret_dict[attr.name] = np.nonzero(np.asarray(attr.form) == axis_name)[0]
797797
return ret_dict
798798

src/pyuvdata/uvdata/uvdata.py

Lines changed: 96 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,7 @@ def _axis_pad_helper(this: UVData, axis_name: str, add_len: int):
197197
setattr(this, param, new_array)
198198

199199

200-
def _fill_multi_helper(
201-
this: UVData, other: UVData, t2o_dict: dict, sort_axes: list[str], order_dict: dict
202-
):
200+
def _fill_multi_helper(this: UVData, other: UVData, t2o_dict: dict, order_dict: dict):
203201
"""
204202
Fill UVParameter objects with multiple dimensions from the right side object.
205203

@@ -212,8 +210,6 @@ def _fill_multi_helper(
212210
t2o_dict : dict
213211
dict giving the indices in the left object to be filled from the right
214212
object for each axis (keys are axes, values are index arrays).
215-
sort_axes : list of str
216-
The axes that need to be sorted along.
217213
order_dict : dict
218214
dict giving the final sort indices for each axis (keys are axes, values
219215
are index arrays for sorting).
@@ -231,7 +227,7 @@ def _fill_multi_helper(
231227

232228
# Fix ordering
233229
for axis_ind, axis in enumerate(form):
234-
if axis in sort_axes:
230+
if order_dict[axis] is not None:
235231
unique_order_diffs = np.unique(np.diff(order_dict[axis]))
236232
if np.array_equal(unique_order_diffs, np.array([1])):
237233
# everything is already in order
@@ -5711,59 +5707,72 @@ def __add__(
57115707
# Define parameters that must be the same to add objects
57125708
compatibility_params = ["_vis_units"]
57135709

5714-
# identify params that are not explicitly included in overlap calc per axis
57155710
axes = ["Nblts", "Nfreqs", "Npols"]
5716-
axis_params_check = {}
5717-
axis_overlap_params = {
5711+
# axis_key_params defines which parameters to use as the defining
5712+
# parameters along each axis. These are used to identify overlapping data.
5713+
axis_key_params = {
57185714
"Nblts": ["time_array", "baseline_array"],
57195715
"Nfreqs": ["freq_array", "flex_spw_id_array"],
57205716
"Npols": ["polarization_array"],
57215717
}
5722-
axis_combined_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"}
5723-
axis_dict = {}
5724-
for axis in axes:
5725-
axis_dict[axis] = this._get_param_axis(axis)
5726-
axis_params_check[axis] = []
5727-
for param in axis_dict[axis]:
5718+
# specify a function to form a combined string if there are multiple
5719+
# key arrays (e.g. baseline-time, spw-freq)
5720+
axis_key_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"}
5721+
multi_axis_params = this._get_multi_axis_params()
5722+
# axis_parameters gives parameters whose form contains each axis
5723+
axis_parameters = {}
5724+
# axis_check_params gives parameters that should be checked if adding
5725+
# along other axes
5726+
axis_check_params = {}
5727+
# axis_key_arrays gives the arrays to use for checking for overlap per axis
5728+
axis_key_arrays = {}
5729+
# axis_overlap_inds has the outcomes of np.intersect1d on the
5730+
# axis_key_arrays per axis. So it has the both/this inds/other inds
5731+
# for any overlaps.
5732+
axis_overlap_inds = {}
5733+
for axis, overlap_params in axis_key_params.items():
5734+
axis_parameters[axis] = this._get_param_axis(axis)
5735+
axis_check_params[axis] = []
5736+
for param in axis_parameters[axis]:
5737+
# get parameters for compatibility checking. Exclude parameters
5738+
# that define overlap and multidimensional parameters which are
5739+
# handled separately later.
57285740
if (
5729-
param not in this._data_params
5730-
and param not in axis_overlap_params[axis]
5741+
param not in multi_axis_params
5742+
and param not in axis_key_params[axis]
57315743
):
5732-
axis_params_check[axis].append("_" + param)
5744+
axis_check_params[axis].append("_" + param)
57335745

5734-
# build this/other arrays for checking for overlap.
5735-
# Use a combined string if there are multiple arrays defining overlap
5736-
# (e.g. baseline-time, spw-freq)
5737-
axis_vals = {}
5738-
for axis, overlap_params in axis_overlap_params.items():
5746+
# build this/other arrays for checking for overlap.
57395747
if len(overlap_params) > 1:
5740-
axis_vals[axis] = {
5741-
"this": getattr(this, axis_combined_func[axis])(),
5742-
"other": getattr(other, axis_combined_func[axis])(),
5748+
axis_key_arrays[axis] = {
5749+
"this": getattr(this, axis_key_func[axis])(),
5750+
"other": getattr(other, axis_key_func[axis])(),
57435751
}
57445752
else:
5745-
axis_vals[axis] = {
5753+
axis_key_arrays[axis] = {
57465754
"this": getattr(this, overlap_params[0]),
57475755
"other": getattr(other, overlap_params[0]),
57485756
}
57495757

5750-
# Check if we have overlapping data
5751-
axis_inds = {}
5752-
for axis, val_arr in axis_vals.items():
5758+
# Check if we have overlapping data
57535759
both_inds, this_inds, other_inds = np.intersect1d(
5754-
val_arr["this"], val_arr["other"], return_indices=True
5760+
axis_key_arrays[axis]["this"],
5761+
axis_key_arrays[axis]["other"],
5762+
return_indices=True,
57555763
)
5756-
axis_inds[axis] = {
5764+
axis_overlap_inds[axis] = {
57575765
"this": this_inds,
57585766
"other": other_inds,
57595767
"both": both_inds,
57605768
}
57615769

57625770
history_update_string = ""
57635771

5764-
if np.all([len(axis_inds[axis]["both"]) > 0 for axis in axis_inds]):
5772+
if np.all(
5773+
[len(axis_overlap_inds[axis]["both"]) > 0 for axis in axis_overlap_inds]
5774+
):
57655775
# We have overlaps, check that overlapping data is not valid
5766-
multi_axis_params = this._get_multi_axis_params()
57675776
this_test = []
57685777
other_test = []
57695778
for param in multi_axis_params:
@@ -5778,10 +5787,14 @@ def __add__(
57785787
for ax_ind, axis in enumerate(form):
57795788
expand_axes = [ax for ax in range(len(form)) if ax != ax_ind]
57805789
this_index_list.append(
5781-
np.expand_dims(axis_inds[axis]["this"], axis=expand_axes)
5790+
np.expand_dims(
5791+
axis_overlap_inds[axis]["this"], axis=expand_axes
5792+
)
57825793
)
57835794
other_index_list.append(
5784-
np.expand_dims(axis_inds[axis]["other"], axis=expand_axes)
5795+
np.expand_dims(
5796+
axis_overlap_inds[axis]["other"], axis=expand_axes
5797+
)
57855798
)
57865799
this_inds = np.ravel_multi_index(this_index_list, this_shape).flatten()
57875800

@@ -5813,7 +5826,9 @@ def __add__(
58135826
"These objects have overlapping data and cannot be combined."
58145827
)
58155828

5816-
new_inds = {}
5829+
# Now actually find which axes are going to be added along
5830+
# other_inds_use have the indices in other that will be added to this
5831+
other_inds_use = {}
58175832
additions = []
58185833
axis_descriptions = {
58195834
"Nblts": "baseline-time",
@@ -5823,30 +5838,30 @@ def __add__(
58235838
# find the indices in "other" but not in "this"
58245839
for axis in axes:
58255840
temp = np.nonzero(
5826-
~np.isin(axis_vals[axis]["other"], axis_vals[axis]["this"])
5841+
~np.isin(axis_key_arrays[axis]["other"], axis_key_arrays[axis]["this"])
58275842
)[0]
58285843
if len(temp) > 0:
5829-
new_inds[axis] = temp
5844+
other_inds_use[axis] = temp
58305845
# add params associated with the other axes to compatibility_params
58315846
for axis2 in axes:
58325847
if axis2 != axis:
5833-
compatibility_params.extend(axis_params_check[axis2])
5848+
compatibility_params.extend(axis_check_params[axis2])
58345849
additions.append(axis_descriptions[axis])
58355850
else:
5836-
new_inds[axis] = []
5851+
other_inds_use[axis] = []
58375852

58385853
# Actually check compatibility parameters
58395854
for cp in compatibility_params:
58405855
params_match = None
5841-
for axis, check_list in axis_params_check.items():
5856+
for axis, check_list in axis_check_params.items():
58425857
if cp in check_list:
58435858
# only check that overlapping indices match
58445859
this_param = getattr(this, cp)
58455860
this_param_overlap = this_param.get_from_form(
5846-
{axis: axis_inds[axis]["this"]}
5861+
{axis: axis_overlap_inds[axis]["this"]}
58475862
)
58485863
other_param_overlap = getattr(other, cp).get_from_form(
5849-
{axis: axis_inds[axis]["other"]}
5864+
{axis: axis_overlap_inds[axis]["other"]}
58505865
)
58515866
params_match = np.allclose(
58525867
this_param_overlap,
@@ -5878,7 +5893,7 @@ def __add__(
58785893
"Nfreqs": {"method": "reorder_freqs", "parameter": "channel_order"},
58795894
"Npols": {"method": "reorder_pols", "parameter": "order"},
58805895
}
5881-
for axis, ind_dict in axis_inds.items():
5896+
for axis, ind_dict in axis_overlap_inds.items():
58825897
if len(ind_dict["this"]) != 0:
58835898
# there is some overlap, so check sorting
58845899
this_argsort = np.argsort(ind_dict["this"])
@@ -5893,16 +5908,22 @@ def __add__(
58935908

58945909
getattr(this, reorder_method[axis]["method"])(**kwargs)
58955910

5896-
# start updating parameters
5897-
new_axis_inds = {}
5911+
# checks are all done, start updating parameters
5912+
# combined_key_arrays has the final key arrays after adding.
5913+
combined_key_arrays = {}
5914+
# order_dict has info about how to sort each axis. Initialize to None
5915+
# for axes that are not added along (so do not need sorting)
58985916
order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None}
58995917
for axis in axes:
5900-
if len(new_inds[axis]) > 0:
5901-
new_axis_inds[axis] = np.concatenate(
5902-
(axis_vals[axis]["this"], axis_vals[axis]["other"][new_inds[axis]])
5918+
if len(other_inds_use[axis]) > 0:
5919+
combined_key_arrays[axis] = np.concatenate(
5920+
(
5921+
axis_key_arrays[axis]["this"],
5922+
axis_key_arrays[axis]["other"][other_inds_use[axis]],
5923+
)
59035924
)
59045925
if axis == "Npols":
5905-
order_dict[axis] = np.argsort(np.abs(new_axis_inds[axis]))
5926+
order_dict[axis] = np.argsort(np.abs(combined_key_arrays[axis]))
59065927
elif axis == "Nfreqs" and (
59075928
np.any(np.diff(this.freq_array) < 0)
59085929
or np.any(np.diff(other.freq_array) < 0)
@@ -5913,38 +5934,39 @@ def __add__(
59135934
np.concatenate(
59145935
(
59155936
this.flex_spw_id_array,
5916-
other.flex_spw_id_array[new_inds[axis]],
5937+
other.flex_spw_id_array[other_inds_use[axis]],
59175938
)
59185939
),
59195940
np.concatenate(
5920-
(this.freq_array, other.freq_array[new_inds[axis]])
5941+
(this.freq_array, other.freq_array[other_inds_use[axis]])
59215942
),
59225943
)
59235944
else:
5924-
order_dict[axis] = np.argsort(new_axis_inds[axis])
5945+
order_dict[axis] = np.argsort(combined_key_arrays[axis])
59255946

5926-
# first handle parameters with a single axis
5927-
_axis_add_helper(this, other, axis, new_inds[axis], order_dict[axis])
5947+
# first handle parameters with a single named axis
5948+
_axis_add_helper(
5949+
this, other, axis, other_inds_use[axis], order_dict[axis]
5950+
)
59285951

59295952
# then pad out parameters with multiple axes
5930-
_axis_pad_helper(this, axis, len(new_inds[axis]))
5953+
_axis_pad_helper(this, axis, len(other_inds_use[axis]))
59315954
else:
5932-
new_axis_inds[axis] = axis_vals[axis]["this"]
5955+
# no add along this axis, so it's the same as what's already on this
5956+
combined_key_arrays[axis] = axis_key_arrays[axis]["this"]
59335957

59345958
# Now fill in multidimensional parameters
5959+
# t2o_dict has the mapping of where arrays on other get mapped into
5960+
# this after padding
59355961
t2o_dict = {}
5936-
for axis, inds_dict in axis_vals.items():
5962+
for axis, inds_dict in axis_key_arrays.items():
59375963
t2o_dict[axis] = np.nonzero(
5938-
np.isin(new_axis_inds[axis], inds_dict["other"])
5964+
np.isin(combined_key_arrays[axis], inds_dict["other"])
59395965
)[0]
59405966

5941-
sort_axes = []
5942-
for axis in axes:
5943-
if len(new_inds[axis]) > 0:
5944-
sort_axes.append(axis)
5945-
_fill_multi_helper(this, other, t2o_dict, sort_axes, order_dict)
5967+
_fill_multi_helper(this, other, t2o_dict, order_dict)
59465968

5947-
if len(new_inds["Nfreqs"]) > 0:
5969+
if len(other_inds_use["Nfreqs"]) > 0:
59485970
# We want to preserve per-spw information based on first appearance
59495971
# in the concatenated array.
59505972
unique_index = np.sort(
@@ -5998,7 +6020,7 @@ def __add__(
59986020
)
59996021

60006022
# Reset blt_order if blt axis was added to
6001-
if "Nblts" in sort_axes:
6023+
if order_dict["Nblts"] is not None:
60026024
this.blt_order = ("time", "baseline")
60036025

60046026
this.set_rectangularity(force=True)
@@ -6214,18 +6236,18 @@ def fast_concat(
62146236

62156237
# identify params that are not explicitly included in overlap calc per axis
62166238
axis_shape = {"blt": "Nblts", "freq": "Nfreqs", "polarization": "Npols"}
6217-
axis_params_check = {}
6218-
axis_dict = {}
6239+
axis_check_params = {}
6240+
axis_parameters = {}
62196241
for axis2, ax_shape in axis_shape.items():
6220-
axis_dict[axis2] = this._get_param_axis(ax_shape)
6221-
axis_params_check[axis2] = []
6222-
for param in axis_dict[axis2]:
6242+
axis_parameters[axis2] = this._get_param_axis(ax_shape)
6243+
axis_check_params[axis2] = []
6244+
for param in axis_parameters[axis2]:
62236245
if param not in this._data_params:
6224-
axis_params_check[axis2].append("_" + param)
6246+
axis_check_params[axis2].append("_" + param)
62256247

62266248
for axis2 in axis_shape:
62276249
if axis2 != axis:
6228-
compatibility_params.extend(axis_params_check[axis2])
6250+
compatibility_params.extend(axis_check_params[axis2])
62296251

62306252
axis_descriptions = {
62316253
"blt": "baseline-time",

0 commit comments

Comments
 (0)