Skip to content

Commit 31917eb

Browse files
committed
Move functions out of uvdata for better reusability
1 parent 1b6d1e1 commit 31917eb

File tree

4 files changed

+223
-231
lines changed

4 files changed

+223
-231
lines changed

src/pyuvdata/utils/frequency.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from . import tools
1010
from .pol import jstr2num, polstr2num
11+
from .types import FloatArray, IntArray
1112

1213

1314
def _check_flex_spw_contiguous(*, spw_array, flex_spw_id_array, strict=True):
@@ -549,3 +550,41 @@ def _select_freq_helper(
549550
freq_inds = freq_inds.tolist()
550551

551552
return freq_inds, spw_inds, selections
553+
554+
555+
def _add_freq_order(spw_id: IntArray, freq_arr: FloatArray) -> IntArray:
556+
"""
557+
Get the sorting order for the frequency axis after an add.
558+
559+
Sort first by spw then by channel, but don't reorder channels if they are
560+
changing monotonically (all ascending or descending) within the spw.
561+
562+
Parameters
563+
----------
564+
spw_id : np.ndarray of int
565+
SPW id array of combined data to be sorted.
566+
freq_arr : np.ndarray of float
567+
Frequency array of combined data to be sorted.
568+
569+
Returns
570+
-------
571+
f_order : np.ndarray of int
572+
index array giving the sort order.
573+
574+
"""
575+
spws = np.unique(spw_id)
576+
f_order = np.concatenate([np.where(spw_id == spw)[0] for spw in np.unique(spw_id)])
577+
578+
# With spectral windows sorted, check and see if channels within
579+
# windows need sorting. If they are ordered in ascending or descending
580+
# fashion, leave them be. If not, sort in ascending order
581+
for spw in spws:
582+
select_mask = spw_id[f_order] == spw
583+
check_freqs = freq_arr[f_order[select_mask]]
584+
if not np.all(np.diff(check_freqs) > 0) and not np.all(
585+
np.diff(check_freqs) < 0
586+
):
587+
subsort_order = f_order[select_mask]
588+
f_order[select_mask] = subsort_order[np.argsort(check_freqs)]
589+
590+
return f_order

src/pyuvdata/utils/tools.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
import numpy as np
1111

12+
from .types import FloatArray, IntArray, StrArray
13+
1214

1315
def _get_iterable(x):
1416
"""Return iterable version of input."""
@@ -687,3 +689,43 @@ def _ntimes_to_nblts(uvd):
687689
inds.append(np.where(unique_t == i)[0][0])
688690

689691
return np.asarray(inds)
692+
693+
694+
def flt_ind_str_arr(
695+
*,
696+
fltarr: FloatArray,
697+
intarr: IntArray,
698+
flt_tols: tuple[float, float],
699+
flt_first: bool = True,
700+
) -> StrArray:
701+
"""
702+
Create a string array built from float and integer arrays for matching.
703+
704+
Parameters
705+
----------
706+
fltarr : np.ndarray of float
707+
float array to be used in output string array
708+
intarr : np.ndarray of int
709+
integer array to be used in output string array
710+
flt_tols : 2-tuple of float
711+
Tolerances (relative, absolute) to use in formatting the floats as strings.
712+
flt_first : bool
713+
Whether to put the float first in the out put string or not (if False
714+
the int comes first.)
715+
716+
Returns
717+
-------
718+
np.ndarray of str
719+
String array that combines the float and integer values, useful for matching.
720+
721+
"""
722+
prec_flt = -2 * np.floor(np.log10(flt_tols[-1])).astype(int)
723+
prec_int = 8
724+
flt_str_list = ["{1:.{0}f}".format(prec_flt, flt) for flt in fltarr]
725+
int_str_list = [str(intv).zfill(prec_int) for intv in intarr]
726+
list_of_lists = []
727+
if flt_first:
728+
list_of_lists = [flt_str_list, int_str_list]
729+
else:
730+
list_of_lists = [int_str_list, flt_str_list]
731+
return np.array(["_".join(zpval) for zpval in zip(*list_of_lists, strict=True)])

src/pyuvdata/uvbase.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from . import __version__, parameter as uvp
1818
from .utils.tools import _get_iterable, slicify
19+
from .utils.types import IntArray
1920

2021
__all__ = ["UVBase"]
2122

@@ -865,3 +866,135 @@ def _select_along_param_axis(self, param_dict: dict):
865866
# here in the case of a repeated param_name in the form.
866867
attr.value = attr.get_from_form(slice_dict)
867868
attr.setter(self)
869+
870+
def _axis_add_helper(
871+
self,
872+
other,
873+
axis_name: str,
874+
other_inds: IntArray,
875+
final_order: IntArray | None = None,
876+
):
877+
"""
878+
Combine UVParameter objects with a single axis along an axis.
879+
880+
Parameters
881+
----------
882+
other : UVBase
883+
The UVBase object to be added.
884+
axis_name : str
885+
The axis name (e.g. "Nblts", "Npols").
886+
other_inds : np.ndarray of int
887+
Indices into the other object along this axis to include.
888+
final_order : np.ndarray of int
889+
Final ordering array giving the sort order after concatenation.
890+
891+
"""
892+
update_params = self._get_param_axis(axis_name, single_named_axis=True)
893+
other_form_dict = {axis_name: other_inds}
894+
for param, axis_list in update_params.items():
895+
axis = axis_list[0]
896+
new_array = np.concatenate(
897+
[
898+
getattr(self, param),
899+
getattr(other, "_" + param).get_from_form(other_form_dict),
900+
],
901+
axis=axis,
902+
)
903+
if final_order is not None:
904+
new_array = np.take(new_array, final_order, axis=axis)
905+
906+
setattr(self, param, new_array)
907+
908+
def _axis_pad_helper(self, axis_name: str, add_len: int):
909+
"""
910+
Pad out UVParameter objects with multiple dimensions along an axis.
911+
912+
Parameters
913+
----------
914+
axis_name : str
915+
The axis name (e.g. "Nblts", "Npols").
916+
add_len : int
917+
The extra length to be padded on for this axis.
918+
919+
"""
920+
update_params = self._get_param_axis(axis_name)
921+
multi_axis_params = self._get_multi_axis_params()
922+
for param, axis_list in update_params.items():
923+
if param not in multi_axis_params:
924+
continue
925+
this_param_shape = getattr(self, param).shape
926+
this_param_type = getattr(self, "_" + param).expected_type
927+
bool_type = this_param_type is bool or bool in this_param_type
928+
pad_shape = list(this_param_shape)
929+
for ax in axis_list:
930+
pad_shape[ax] = add_len
931+
if bool_type:
932+
pad_array = np.ones(tuple(pad_shape), dtype=bool)
933+
else:
934+
pad_array = np.zeros(tuple(pad_shape))
935+
new_array = np.concatenate([getattr(self, param), pad_array], axis=ax)
936+
if bool_type:
937+
new_array = new_array.astype(np.bool_)
938+
setattr(self, param, new_array)
939+
940+
def _fill_multi_helper(self, other, t2o_dict: dict, order_dict: dict):
941+
"""
942+
Fill UVParameter objects with multiple dimensions from the right side object.
943+
944+
Parameters
945+
----------
946+
other : UVBase
947+
The UVBase object to be added.
948+
t2o_dict : dict
949+
dict giving the indices in the left object to be filled from the right
950+
object for each axis (keys are axes, values are index arrays).
951+
order_dict : dict
952+
dict giving the final sort indices for each axis (keys are axes, values
953+
are index arrays for sorting).
954+
955+
"""
956+
multi_axis_params = self._get_multi_axis_params()
957+
for param in multi_axis_params:
958+
form = getattr(self, "_" + param).form
959+
index_list = []
960+
for axis in form:
961+
index_list.append(t2o_dict[axis])
962+
new_arr = getattr(self, param)
963+
new_arr[np.ix_(*index_list)] = getattr(other, param)
964+
setattr(self, param, new_arr)
965+
966+
# Fix ordering
967+
for axis_ind, axis in enumerate(form):
968+
if order_dict[axis] is not None:
969+
unique_order_diffs = np.unique(np.diff(order_dict[axis]))
970+
if np.array_equal(unique_order_diffs, np.array([1])):
971+
# everything is already in order
972+
continue
973+
setattr(
974+
self,
975+
param,
976+
np.take(getattr(self, param), order_dict[axis], axis=axis_ind),
977+
)
978+
979+
def _axis_fast_concat_helper(self, other, axis_name: str):
980+
"""
981+
Concatenate UVParameter objects along an axis assuming no overlap.
982+
983+
Parameters
984+
----------
985+
other : UVBase
986+
The UVBase object to be added.
987+
axis_name : str
988+
The axis name (e.g. "Nblts", "Npols").
989+
"""
990+
update_params = self._get_param_axis(axis_name)
991+
for param, axis_list in update_params.items():
992+
axis = axis_list[0]
993+
setattr(
994+
self,
995+
param,
996+
np.concatenate(
997+
[getattr(self, param)] + [getattr(obj, param) for obj in other],
998+
axis=axis,
999+
),
1000+
)

0 commit comments

Comments
 (0)