Skip to content

Commit 1b6d1e1

Browse files
committed
Better comments and variable names
1 parent 572a748 commit 1b6d1e1

File tree

2 files changed

+99
-77
lines changed

2 files changed

+99
-77
lines changed

src/pyuvdata/uvbase.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False):
776776
ret_dict = {}
777777
for param in self:
778778
# For each attribute, if the value is None, then bail, otherwise
779-
# attempt to figure out along which axis ind_arr will apply.
779+
# find the axis number(s) with the named shape.
780780

781781
attr = getattr(self, param)
782782
if (
@@ -791,8 +791,8 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False):
791791
continue
792792

793793
# Only look at where form is a tuple, since that's the only case we
794-
# can have a dynamically defined shape. Note that index doesn't work
795-
# here in the case of a repeated param_name in the form.
794+
# can have a dynamically defined shape. Handle a repeated
795+
# param_name in the form.
796796
ret_dict[attr.name] = np.nonzero(np.asarray(attr.form) == axis_name)[0]
797797
return ret_dict
798798

src/pyuvdata/uvdata/uvdata.py

Lines changed: 96 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,7 @@ def _axis_pad_helper(this: UVData, axis_name: str, add_len: int):
195195
setattr(this, param, new_array)
196196

197197

198-
def _fill_multi_helper(
199-
this: UVData, other: UVData, t2o_dict: dict, sort_axes: list[str], order_dict: dict
200-
):
198+
def _fill_multi_helper(this: UVData, other: UVData, t2o_dict: dict, order_dict: dict):
201199
"""
202200
Fill UVParameter objects with multiple dimensions from the right side object.
203201

@@ -210,8 +208,6 @@ def _fill_multi_helper(
210208
t2o_dict : dict
211209
dict giving the indices in the left object to be filled from the right
212210
object for each axis (keys are axes, values are index arrays).
213-
sort_axes : list of str
214-
The axes that need to be sorted along.
215211
order_dict : dict
216212
dict giving the final sort indices for each axis (keys are axes, values
217213
are index arrays for sorting).
@@ -229,7 +225,7 @@ def _fill_multi_helper(
229225

230226
# Fix ordering
231227
for axis_ind, axis in enumerate(form):
232-
if axis in sort_axes:
228+
if order_dict[axis] is not None:
233229
unique_order_diffs = np.unique(np.diff(order_dict[axis]))
234230
if np.array_equal(unique_order_diffs, np.array([1])):
235231
# everything is already in order
@@ -5709,59 +5705,72 @@ def __add__(
57095705
# Define parameters that must be the same to add objects
57105706
compatibility_params = ["_vis_units"]
57115707

5712-
# identify params that are not explicitly included in overlap calc per axis
57135708
axes = ["Nblts", "Nfreqs", "Npols"]
5714-
axis_params_check = {}
5715-
axis_overlap_params = {
5709+
# axis_key_params defines which parameters to use as the defining
5710+
# parameters along each axis. These are used to identify overlapping data.
5711+
axis_key_params = {
57165712
"Nblts": ["time_array", "baseline_array"],
57175713
"Nfreqs": ["freq_array", "flex_spw_id_array"],
57185714
"Npols": ["polarization_array"],
57195715
}
5720-
axis_combined_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"}
5721-
axis_dict = {}
5722-
for axis in axes:
5723-
axis_dict[axis] = this._get_param_axis(axis)
5724-
axis_params_check[axis] = []
5725-
for param in axis_dict[axis]:
5716+
# specify a function to form a combined string if there are multiple
5717+
# key arrays (e.g. baseline-time, spw-freq)
5718+
axis_key_func = {"Nblts": "blt_str_arr", "Nfreqs": "spw_freq_str_arr"}
5719+
multi_axis_params = this._get_multi_axis_params()
5720+
# axis_parameters gives parameters whose form contains each axis
5721+
axis_parameters = {}
5722+
# axis_check_params gives parameters that should be checked if adding
5723+
# along other axes
5724+
axis_check_params = {}
5725+
# axis_key_arrays gives the arrays to use for checking for overlap per axis
5726+
axis_key_arrays = {}
5727+
# axis_overlap_inds has the outcomes of np.intersect1d on the
5728+
# axis_key_arrays per axis. So it has the both/this inds/other inds
5729+
# for any overlaps.
5730+
axis_overlap_inds = {}
5731+
for axis, overlap_params in axis_key_params.items():
5732+
axis_parameters[axis] = this._get_param_axis(axis)
5733+
axis_check_params[axis] = []
5734+
for param in axis_parameters[axis]:
5735+
# get parameters for compatibility checking. Exclude parameters
5736+
# that define overlap and multidimensional parameters which are
5737+
# handled separately later.
57265738
if (
5727-
param not in this._data_params
5728-
and param not in axis_overlap_params[axis]
5739+
param not in multi_axis_params
5740+
and param not in axis_key_params[axis]
57295741
):
5730-
axis_params_check[axis].append("_" + param)
5742+
axis_check_params[axis].append("_" + param)
57315743

5732-
# build this/other arrays for checking for overlap.
5733-
# Use a combined string if there are multiple arrays defining overlap
5734-
# (e.g. baseline-time, spw-freq)
5735-
axis_vals = {}
5736-
for axis, overlap_params in axis_overlap_params.items():
5744+
# build this/other arrays for checking for overlap.
57375745
if len(overlap_params) > 1:
5738-
axis_vals[axis] = {
5739-
"this": getattr(this, axis_combined_func[axis])(),
5740-
"other": getattr(other, axis_combined_func[axis])(),
5746+
axis_key_arrays[axis] = {
5747+
"this": getattr(this, axis_key_func[axis])(),
5748+
"other": getattr(other, axis_key_func[axis])(),
57415749
}
57425750
else:
5743-
axis_vals[axis] = {
5751+
axis_key_arrays[axis] = {
57445752
"this": getattr(this, overlap_params[0]),
57455753
"other": getattr(other, overlap_params[0]),
57465754
}
57475755

5748-
# Check if we have overlapping data
5749-
axis_inds = {}
5750-
for axis, val_arr in axis_vals.items():
5756+
# Check if we have overlapping data
57515757
both_inds, this_inds, other_inds = np.intersect1d(
5752-
val_arr["this"], val_arr["other"], return_indices=True
5758+
axis_key_arrays[axis]["this"],
5759+
axis_key_arrays[axis]["other"],
5760+
return_indices=True,
57535761
)
5754-
axis_inds[axis] = {
5762+
axis_overlap_inds[axis] = {
57555763
"this": this_inds,
57565764
"other": other_inds,
57575765
"both": both_inds,
57585766
}
57595767

57605768
history_update_string = ""
57615769

5762-
if np.all([len(axis_inds[axis]["both"]) > 0 for axis in axis_inds]):
5770+
if np.all(
5771+
[len(axis_overlap_inds[axis]["both"]) > 0 for axis in axis_overlap_inds]
5772+
):
57635773
# We have overlaps, check that overlapping data is not valid
5764-
multi_axis_params = this._get_multi_axis_params()
57655774
this_test = []
57665775
other_test = []
57675776
for param in multi_axis_params:
@@ -5776,10 +5785,14 @@ def __add__(
57765785
for ax_ind, axis in enumerate(form):
57775786
expand_axes = [ax for ax in range(len(form)) if ax != ax_ind]
57785787
this_index_list.append(
5779-
np.expand_dims(axis_inds[axis]["this"], axis=expand_axes)
5788+
np.expand_dims(
5789+
axis_overlap_inds[axis]["this"], axis=expand_axes
5790+
)
57805791
)
57815792
other_index_list.append(
5782-
np.expand_dims(axis_inds[axis]["other"], axis=expand_axes)
5793+
np.expand_dims(
5794+
axis_overlap_inds[axis]["other"], axis=expand_axes
5795+
)
57835796
)
57845797
this_inds = np.ravel_multi_index(this_index_list, this_shape).flatten()
57855798

@@ -5811,7 +5824,9 @@ def __add__(
58115824
"These objects have overlapping data and cannot be combined."
58125825
)
58135826

5814-
new_inds = {}
5827+
# Now actually find which axes are going to be added along
5828+
# other_inds_use have the indices in other that will be added to this
5829+
other_inds_use = {}
58155830
additions = []
58165831
axis_descriptions = {
58175832
"Nblts": "baseline-time",
@@ -5821,30 +5836,30 @@ def __add__(
58215836
# find the indices in "other" but not in "this"
58225837
for axis in axes:
58235838
temp = np.nonzero(
5824-
~np.isin(axis_vals[axis]["other"], axis_vals[axis]["this"])
5839+
~np.isin(axis_key_arrays[axis]["other"], axis_key_arrays[axis]["this"])
58255840
)[0]
58265841
if len(temp) > 0:
5827-
new_inds[axis] = temp
5842+
other_inds_use[axis] = temp
58285843
# add params associated with the other axes to compatibility_params
58295844
for axis2 in axes:
58305845
if axis2 != axis:
5831-
compatibility_params.extend(axis_params_check[axis2])
5846+
compatibility_params.extend(axis_check_params[axis2])
58325847
additions.append(axis_descriptions[axis])
58335848
else:
5834-
new_inds[axis] = []
5849+
other_inds_use[axis] = []
58355850

58365851
# Actually check compatibility parameters
58375852
for cp in compatibility_params:
58385853
params_match = None
5839-
for axis, check_list in axis_params_check.items():
5854+
for axis, check_list in axis_check_params.items():
58405855
if cp in check_list:
58415856
# only check that overlapping indices match
58425857
this_param = getattr(this, cp)
58435858
this_param_overlap = this_param.get_from_form(
5844-
{axis: axis_inds[axis]["this"]}
5859+
{axis: axis_overlap_inds[axis]["this"]}
58455860
)
58465861
other_param_overlap = getattr(other, cp).get_from_form(
5847-
{axis: axis_inds[axis]["other"]}
5862+
{axis: axis_overlap_inds[axis]["other"]}
58485863
)
58495864
params_match = np.allclose(
58505865
this_param_overlap,
@@ -5876,7 +5891,7 @@ def __add__(
58765891
"Nfreqs": {"method": "reorder_freqs", "parameter": "channel_order"},
58775892
"Npols": {"method": "reorder_pols", "parameter": "order"},
58785893
}
5879-
for axis, ind_dict in axis_inds.items():
5894+
for axis, ind_dict in axis_overlap_inds.items():
58805895
if len(ind_dict["this"]) != 0:
58815896
# there is some overlap, so check sorting
58825897
this_argsort = np.argsort(ind_dict["this"])
@@ -5891,16 +5906,22 @@ def __add__(
58915906

58925907
getattr(this, reorder_method[axis]["method"])(**kwargs)
58935908

5894-
# start updating parameters
5895-
new_axis_inds = {}
5909+
# checks are all done, start updating parameters
5910+
# combined_key_arrays has the final key arrays after adding.
5911+
combined_key_arrays = {}
5912+
# order_dict has info about how to sort each axis. Initialize to None
5913+
# for axes that are not added along (so do not need sorting)
58965914
order_dict = {"Nblts": None, "Nfreqs": None, "Npols": None}
58975915
for axis in axes:
5898-
if len(new_inds[axis]) > 0:
5899-
new_axis_inds[axis] = np.concatenate(
5900-
(axis_vals[axis]["this"], axis_vals[axis]["other"][new_inds[axis]])
5916+
if len(other_inds_use[axis]) > 0:
5917+
combined_key_arrays[axis] = np.concatenate(
5918+
(
5919+
axis_key_arrays[axis]["this"],
5920+
axis_key_arrays[axis]["other"][other_inds_use[axis]],
5921+
)
59015922
)
59025923
if axis == "Npols":
5903-
order_dict[axis] = np.argsort(np.abs(new_axis_inds[axis]))
5924+
order_dict[axis] = np.argsort(np.abs(combined_key_arrays[axis]))
59045925
elif axis == "Nfreqs" and (
59055926
np.any(np.diff(this.freq_array) < 0)
59065927
or np.any(np.diff(other.freq_array) < 0)
@@ -5911,38 +5932,39 @@ def __add__(
59115932
np.concatenate(
59125933
(
59135934
this.flex_spw_id_array,
5914-
other.flex_spw_id_array[new_inds[axis]],
5935+
other.flex_spw_id_array[other_inds_use[axis]],
59155936
)
59165937
),
59175938
np.concatenate(
5918-
(this.freq_array, other.freq_array[new_inds[axis]])
5939+
(this.freq_array, other.freq_array[other_inds_use[axis]])
59195940
),
59205941
)
59215942
else:
5922-
order_dict[axis] = np.argsort(new_axis_inds[axis])
5943+
order_dict[axis] = np.argsort(combined_key_arrays[axis])
59235944

5924-
# first handle parameters with a single axis
5925-
_axis_add_helper(this, other, axis, new_inds[axis], order_dict[axis])
5945+
# first handle parameters with a single named axis
5946+
_axis_add_helper(
5947+
this, other, axis, other_inds_use[axis], order_dict[axis]
5948+
)
59265949

59275950
# then pad out parameters with multiple axes
5928-
_axis_pad_helper(this, axis, len(new_inds[axis]))
5951+
_axis_pad_helper(this, axis, len(other_inds_use[axis]))
59295952
else:
5930-
new_axis_inds[axis] = axis_vals[axis]["this"]
5953+
# no add along this axis, so it's the same as what's already on this
5954+
combined_key_arrays[axis] = axis_key_arrays[axis]["this"]
59315955

59325956
# Now fill in multidimensional parameters
5957+
# t2o_dict has the mapping of where arrays on other get mapped into
5958+
# this after padding
59335959
t2o_dict = {}
5934-
for axis, inds_dict in axis_vals.items():
5960+
for axis, inds_dict in axis_key_arrays.items():
59355961
t2o_dict[axis] = np.nonzero(
5936-
np.isin(new_axis_inds[axis], inds_dict["other"])
5962+
np.isin(combined_key_arrays[axis], inds_dict["other"])
59375963
)[0]
59385964

5939-
sort_axes = []
5940-
for axis in axes:
5941-
if len(new_inds[axis]) > 0:
5942-
sort_axes.append(axis)
5943-
_fill_multi_helper(this, other, t2o_dict, sort_axes, order_dict)
5965+
_fill_multi_helper(this, other, t2o_dict, order_dict)
59445966

5945-
if len(new_inds["Nfreqs"]) > 0:
5967+
if len(other_inds_use["Nfreqs"]) > 0:
59465968
# We want to preserve per-spw information based on first appearance
59475969
# in the concatenated array.
59485970
unique_index = np.sort(
@@ -5996,7 +6018,7 @@ def __add__(
59966018
)
59976019

59986020
# Reset blt_order if blt axis was added to
5999-
if "Nblts" in sort_axes:
6021+
if order_dict["Nblts"] is not None:
60006022
this.blt_order = ("time", "baseline")
60016023

60026024
this.set_rectangularity(force=True)
@@ -6212,18 +6234,18 @@ def fast_concat(
62126234

62136235
# identify params that are not explicitly included in overlap calc per axis
62146236
axis_shape = {"blt": "Nblts", "freq": "Nfreqs", "polarization": "Npols"}
6215-
axis_params_check = {}
6216-
axis_dict = {}
6237+
axis_check_params = {}
6238+
axis_parameters = {}
62176239
for axis2, ax_shape in axis_shape.items():
6218-
axis_dict[axis2] = this._get_param_axis(ax_shape)
6219-
axis_params_check[axis2] = []
6220-
for param in axis_dict[axis2]:
6240+
axis_parameters[axis2] = this._get_param_axis(ax_shape)
6241+
axis_check_params[axis2] = []
6242+
for param in axis_parameters[axis2]:
62216243
if param not in this._data_params:
6222-
axis_params_check[axis2].append("_" + param)
6244+
axis_check_params[axis2].append("_" + param)
62236245

62246246
for axis2 in axis_shape:
62256247
if axis2 != axis:
6226-
compatibility_params.extend(axis_params_check[axis2])
6248+
compatibility_params.extend(axis_check_params[axis2])
62276249

62286250
axis_descriptions = {
62296251
"blt": "baseline-time",

0 commit comments

Comments
 (0)