Skip to content

Commit a480c8f

Browse files
committed
address review comments
1 parent d7897af commit a480c8f

File tree

3 files changed

+58
-21
lines changed

3 files changed

+58
-21
lines changed

src/pyuvdata/utils/tools.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -691,11 +691,11 @@ def _ntimes_to_nblts(uvd):
691691
return np.asarray(inds)
692692

693693

694-
def flt_ind_str_arr(
694+
def float_int_to_str_array(
695695
*,
696696
fltarr: FloatArray,
697697
intarr: IntArray,
698-
flt_tols: tuple[float, float],
698+
flt_tol: tuple[float, float],
699699
flt_first: bool = True,
700700
) -> StrArray:
701701
"""
@@ -707,8 +707,10 @@ def flt_ind_str_arr(
707707
float array to be used in output string array
708708
intarr : np.ndarray of int
709709
integer array to be used in output string array
710-
flt_tols : 2-tuple of float
711-
Tolerances (relative, absolute) to use in formatting the floats as strings.
710+
flt_tol : 2-tuple of float
711+
Absolute tolerance to use in formatting the floats as strings. Note that
712+
this is converted to a decimal place for print formatting, so the precision
713+
might be slightly higher.
712714
flt_first : bool
713715
Whether to put the float first in the out put string or not (if False
714716
the int comes first.)
@@ -718,8 +720,18 @@ def flt_ind_str_arr(
718720
np.ndarray of str
719721
String array that combines the float and integer values, useful for matching.
720722
723+
Examples
724+
--------
725+
>>> float_int_to_str_array(fltarr=[np.pi, np.pi/2], intarr=[1, 2], flt_tol=.01)
726+
array(['3.14_00000001', '1.57_00000002'], dtype='<U13')
727+
728+
>>> float_int_to_str_array(
729+
... fltarr=[np.pi, np.pi/2], intarr=[1, 2], flt_tol=.001, flt_first=False
730+
... )
731+
array(['00000001_3.142', '00000002_1.571'], dtype='<U14')
732+
721733
"""
722-
prec_flt = -2 * np.floor(np.log10(flt_tols[-1])).astype(int)
734+
prec_flt = -1 * np.floor(np.log10(flt_tol)).astype(int)
723735
prec_int = 8
724736
flt_str_list = ["{1:.{0}f}".format(prec_flt, flt) for flt in fltarr]
725737
int_str_list = [str(intv).zfill(prec_int) for intv in intarr]

src/pyuvdata/uvbase.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -755,9 +755,14 @@ def copy(self):
755755
"""
756756
return copy.deepcopy(self)
757757

758-
def _get_param_axis(self, axis_name: str, single_named_axis: bool = False):
758+
def _get_uvparam_axis(self, axis_name: str, single_named_axis: bool = False):
759759
"""
760-
Get a mapping of parameters that have a given axis to the axis number.
760+
Get a mapping of properties that have a given axis to the axis number.
761+
762+
This uses the forms of the UVParameter attributes on this object to identify
763+
properties derived from UVParameters that have one or more axes associated
764+
with the axis_name. Any properties with an associated axis appear as keys
765+
in the output dict, with the values giving the associated axis numbers.
761766
762767
Parameters
763768
----------
@@ -770,9 +775,28 @@ def _get_param_axis(self, axis_name: str, single_named_axis: bool = False):
770775
-------
771776
dict
772777
The keys are UVParameter names that have an axis with axis_name
773-
(axis_name appears in their form). The values are a list of the axis
778+
(axis_name appears in their form). The values are an array of the axis
774779
indices where axis_name appears in their form.
775780
781+
Examples
782+
--------
783+
>>> from pyuvdata import UVData
784+
>>> from pyuvdata.datasets import fetch_data
785+
>>> filename = fetch_data("vla_casa_tutorial_uvfits")
786+
>>> uvd = UVData.from_file(filename)
787+
>>> uvd._get_uvparam_axis("Nfreqs")
788+
{'channel_width': array([0]),
789+
'data_array': array([1]),
790+
'flag_array': array([1]),
791+
'flex_spw_id_array': array([0]),
792+
'freq_array': array([0]),
793+
'nsample_array': array([1])}
794+
795+
>>> uvd._get_uvparam_axis("Nfreqs", single_named_axis=True)
796+
{'channel_width': array([0]),
797+
'flex_spw_id_array': array([0]),
798+
'freq_array': array([0])}
799+
776800
"""
777801
ret_dict = {}
778802
for param in self:
@@ -871,7 +895,7 @@ def _axis_add_helper(
871895
self,
872896
other,
873897
axis_name: str,
874-
other_inds: IntArray,
898+
other_inds: IntArray | None = None,
875899
final_order: IntArray | None = None,
876900
):
877901
"""
@@ -883,13 +907,14 @@ def _axis_add_helper(
883907
The UVBase object to be added.
884908
axis_name : str
885909
The axis name (e.g. "Nblts", "Npols").
886-
other_inds : np.ndarray of int
887-
Indices into the other object along this axis to include.
888-
final_order : np.ndarray of int
910+
other_inds : np.ndarray of int, optional
911+
Indices into the other object along this axis to include. If None,
912+
include all indices.
913+
final_order : np.ndarray of int, optional
889914
Final ordering array giving the sort order after concatenation.
890915
891916
"""
892-
update_params = self._get_param_axis(axis_name, single_named_axis=True)
917+
update_params = self._get_uvparam_axis(axis_name, single_named_axis=True)
893918
other_form_dict = {axis_name: other_inds}
894919
for param, axis_list in update_params.items():
895920
axis = axis_list[0]
@@ -917,7 +942,7 @@ def _axis_pad_helper(self, axis_name: str, add_len: int):
917942
The extra length to be padded on for this axis.
918943
919944
"""
920-
update_params = self._get_param_axis(axis_name)
945+
update_params = self._get_uvparam_axis(axis_name)
921946
multi_axis_params = self._get_multi_axis_params()
922947
for param, axis_list in update_params.items():
923948
if param not in multi_axis_params:
@@ -987,7 +1012,7 @@ def _axis_fast_concat_helper(self, other, axis_name: str):
9871012
axis_name : str
9881013
The axis name (e.g. "Nblts", "Npols").
9891014
"""
990-
update_params = self._get_param_axis(axis_name)
1015+
update_params = self._get_uvparam_axis(axis_name)
9911016
for param, axis_list in update_params.items():
9921017
axis = axis_list[0]
9931018
setattr(

src/pyuvdata/uvdata/uvdata.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5366,19 +5366,19 @@ def fix_phase(self, *, use_ant_pos=True):
53665366

53675367
def blt_str_arr(self) -> StrArray:
53685368
"""Create a string array with baseline and time info for matching purposes."""
5369-
return utils.tools.flt_ind_str_arr(
5369+
return utils.tools.float_int_to_str_array(
53705370
fltarr=self.time_array,
53715371
intarr=self.baseline_array,
5372-
flt_tols=self._time_array.tols,
5372+
flt_tol=self._time_array.tols[1] * 0.1,
53735373
flt_first=True,
53745374
)
53755375

53765376
def spw_freq_str_arr(self) -> StrArray:
53775377
"""Create a string array with spw and freq info for matching purposes."""
5378-
return utils.tools.flt_ind_str_arr(
5378+
return utils.tools.float_int_to_str_array(
53795379
fltarr=self.freq_array,
53805380
intarr=self.flex_spw_id_array,
5381-
flt_tols=self._freq_array.tols,
5381+
flt_tol=self._freq_array.tols[1] * 0.1,
53825382
flt_first=False,
53835383
)
53845384

@@ -5551,7 +5551,7 @@ def __add__(
55515551
for axis, info in axis_info.items():
55525552
# get parameters for compatibility checking. Exclude multidimensional
55535553
# parameters which are handled separately later.
5554-
params_this_axis = this._get_param_axis(axis, single_named_axis=True)
5554+
params_this_axis = this._get_uvparam_axis(axis, single_named_axis=True)
55555555
info["check_params"] = []
55565556
for param in params_this_axis:
55575557
# Also exclude parameters that define overlap
@@ -6058,7 +6058,7 @@ def fast_concat(
60586058
# figure out what parameters to check for compatibility -- only worry
60596059
# about single axis params
60606060
for _, info in axis_info.items():
6061-
params_this_axis = this._get_param_axis(
6061+
params_this_axis = this._get_uvparam_axis(
60626062
info["shape"], single_named_axis=True
60636063
)
60646064
info["check_params"] = []

0 commit comments

Comments
 (0)