Skip to content

Commit 8a8515b

Browse files
committed
Update _data.py
1 parent 1797fd7 commit 8a8515b

1 file changed

Lines changed: 11 additions & 34 deletions

File tree

WrightTools/data/_data.py

Lines changed: 11 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -717,22 +717,20 @@ def gradient(self, axis, *, channel=0):
717717
else:
718718
raise wt_exceptions.TypeError("axis: expected {int, str}, got %s" % type(axis))
719719

720-
channel_index = wt_kit.get_index(self.channel_names, channel)
721-
channel = self.channel_names[channel_index]
720+
channel = self.get_channel(channel)
722721

723-
if self[channel].shape[axis_index] == 1:
722+
if channel.shape[axis_index] == 1:
724723
raise wt_exceptions.ValueError(
725724
"Channel '{}' has a single point along Axis '{}', cannot compute gradient".format(
726725
channel, axis
727726
)
728727
)
729-
rtype = np.result_type(self[channel].dtype, float)
728+
rtype = np.result_type(channel.dtype, float)
730729
new = self.create_channel(
731730
"{}_{}_gradient".format(channel, axis),
732-
values=np.empty(self[channel].shape, dtype=rtype),
731+
values=np.empty(channel.shape, dtype=rtype),
733732
)
734733

735-
channel = self[channel]
736734
if axis == axis_index:
737735
new[:] = np.gradient(channel[:], axis=axis_index)
738736
else:
@@ -1268,13 +1266,7 @@ def get_nadir(self, channel=0) -> tuple:
12681266
Coordinates in units for each axis.
12691267
"""
12701268
# get channel
1271-
if isinstance(channel, int):
1272-
channel_index = channel
1273-
elif isinstance(channel, str):
1274-
channel_index = self.channel_names.index(channel)
1275-
else:
1276-
raise TypeError("channel: expected {int, str}, got %s" % type(channel))
1277-
channel = self.channels[channel_index]
1269+
channel = self.get_channel(channel)
12781270
# get indicies
12791271
idx = channel.argmin()
12801272
# finish
@@ -1294,13 +1286,7 @@ def get_zenith(self, channel=0) -> tuple:
12941286
Coordinates in units for each axis.
12951287
"""
12961288
# get channel
1297-
if isinstance(channel, int):
1298-
channel_index = channel
1299-
elif isinstance(channel, str):
1300-
channel_index = self.channel_names.index(channel)
1301-
else:
1302-
raise TypeError("channel: expected {int, str}, got %s" % type(channel))
1303-
channel = self.channels[channel_index]
1289+
channel:Channel = self.get_channel(channel)
13041290
# get indicies
13051291
idx = channel.argmax()
13061292
# finish
@@ -1336,15 +1322,8 @@ def heal(self, channel=0, method="linear", fill_value=np.nan, verbose=True):
13361322
warnings.warn("heal", category=wt_exceptions.EntireDatasetInMemoryWarning)
13371323
timer = wt_kit.Timer(verbose=False)
13381324
with timer:
1339-
# channel
1340-
if isinstance(channel, int):
1341-
channel_index = channel
1342-
elif isinstance(channel, str):
1343-
channel_index = self.channel_names.index(channel)
1344-
else:
1345-
raise TypeError("channel: expected {int, str}, got %s" % type(channel))
1346-
channel = self.channels[channel_index]
1347-
values = self.channels[channel_index][:]
1325+
channel = self.get_channel(channel)
1326+
values = channel[:]
13481327
points = [axis[:] for axis in self._axes]
13491328
xi = tuple(np.meshgrid(*points, indexing="ij"))
13501329
# 'undo' gridding
@@ -1358,7 +1337,7 @@ def heal(self, channel=0, method="linear", fill_value=np.nan, verbose=True):
13581337
tup = tuple([arr[i] for i in range(len(arr) - 1)])
13591338
# grid data
13601339
out = griddata(tup, arr[-1], xi, method=method, fill_value=fill_value)
1361-
self.channels[channel_index][:] = out
1340+
channel[:] = out
13621341
# print
13631342
if verbose:
13641343
print(
@@ -1384,8 +1363,7 @@ def level(self, channel, axis, npts, *, verbose=True):
13841363
Toggle talkback. Default is True.
13851364
"""
13861365
warnings.warn("level", category=wt_exceptions.EntireDatasetInMemoryWarning)
1387-
channel_index = wt_kit.get_index(self.channel_names, channel)
1388-
channel = self.channels[channel_index]
1366+
channel = self.get_channel(channel)
13891367
# verify npts not zero
13901368
npts = int(npts)
13911369
if npts == 0:
@@ -1443,8 +1421,7 @@ def map_variable(
14431421
New data object.
14441422
"""
14451423
# get variable index
1446-
variable_index = wt_kit.get_index(self.variable_names, variable)
1447-
variable = self.variables[variable_index]
1424+
variable = self.get_var(variable)
14481425
# get points
14491426
if isinstance(points, int):
14501427
points = np.linspace(variable.min(), variable.max(), points)

0 commit comments

Comments
 (0)