Skip to content

Commit 6d9b8e4

Browse files
authored
python(bug): Fix scaling issues with TDMS import (#422)
1 parent 3d3cfee commit 6d9b8e4

2 files changed

Lines changed: 194 additions & 18 deletions

File tree

python/lib/sift_py/data_import/_tdms_test.py

Lines changed: 172 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
1+
import io
12
import json
23
from typing import Any, Dict, List, Optional
34

5+
import numpy as np
46
import pandas as pd
57
import pytest
6-
from nptdms import TdmsFile, types # type: ignore
8+
from nptdms import ( # type: ignore
9+
ChannelObject,
10+
GroupObject,
11+
RootObject,
12+
TdmsFile,
13+
TdmsWriter,
14+
types,
15+
)
716
from pytest_mock import MockFixture
817
from sift.metadata.v1.metadata_pb2 import MetadataKeyType
918

@@ -24,8 +33,28 @@ def __init__(
2433
self.group_name: str = group_name
2534
self.properties: Optional[Dict[str, str]] = properties or {}
2635
self.data: Optional[List[int]] = data or []
36+
self.raw_data = self.data
2737
self.data_type: type = data_type
2838

39+
tdms_to_numpy = {
40+
types.Int8: np.dtype(np.int8),
41+
types.Int16: np.dtype(np.int16),
42+
types.Int32: np.dtype(np.int32),
43+
types.Int64: np.dtype(np.int64),
44+
types.Uint8: np.dtype(np.uint8),
45+
types.Uint16: np.dtype(np.uint16),
46+
types.Uint32: np.dtype(np.uint32),
47+
types.Uint64: np.dtype(np.uint64),
48+
types.SingleFloat: np.dtype(np.float32),
49+
types.DoubleFloat: np.dtype(np.float64),
50+
types.Boolean: np.dtype(np.bool_),
51+
types.String: np.dtype(np.str_),
52+
types.TimeStamp: None,
53+
types.ComplexSingleFloat: np.dtype(np.complex64),
54+
types.ComplexDoubleFloat: np.dtype(np.complex128),
55+
}
56+
self.dtype = tdms_to_numpy[self.data_type]
57+
2958

3059
class MockTdmsGroup:
3160
def __init__(self, name, channels: List[MockTdmsChannel]):
@@ -92,6 +121,39 @@ def mock_waveform_tdms_file():
92121
return MockTdmsFile(mock_tdms_groups)
93122

94123

124+
@pytest.fixture
125+
def waveform_tdms_file_with_scaling():
126+
group = GroupObject("Group 0")
127+
valid_channels = [
128+
ChannelObject(
129+
group="Group 0",
130+
channel=f"Test/channel_{c}",
131+
data=[1, 2, 3],
132+
properties={
133+
"wf_start_time": np.datetime64("2025-10-19T00:00:00.000000"),
134+
"wf_increment": 0.1,
135+
"wf_start_offset": 0,
136+
"extra": "info",
137+
"NI_Scaling_Status": "scaled" if c == 0 else "unscaled",
138+
"NI_Number_Of_Scales": 1,
139+
"NI_Scale[0]_Scale_Type": "Linear",
140+
"NI_Scale[0]_Linear_Slope": 1.5,
141+
"NI_Scale[0]_Linear_Y_Intercept": 10,
142+
"NI_Scale[0]_Linear_Input_Source": 0xFFFFFFFF,
143+
},
144+
)
145+
for c in range(3)
146+
]
147+
148+
file_bytes = io.BytesIO()
149+
with TdmsWriter(file_bytes) as tdms_writer:
150+
root_object = RootObject({})
151+
tdms_writer.write_segment([root_object] + [group] + valid_channels)
152+
153+
file_bytes.seek(0)
154+
return TdmsFile(file_bytes)
155+
156+
95157
@pytest.fixture
96158
def mock_time_channel_tdms_file():
97159
mock_tdms_groups = [
@@ -586,6 +648,7 @@ def test_tdms_upload_unknown_data_type(mocker: MockFixture, mock_waveform_tdms_f
586648
mock_requests_post.return_value = MockResponse()
587649

588650
mock_waveform_tdms_file.groups()[0].channels()[0].data_type = types.ComplexDoubleFloat
651+
mock_waveform_tdms_file.groups()[0].channels()[0].dtype = np.dtype(np.complex128)
589652
mocker.patch("sift_py.data_import.tdms.TdmsFile").return_value = mock_waveform_tdms_file
590653

591654
svc = TdmsUploadService(rest_config)
@@ -887,3 +950,111 @@ def test_tdms_upload_service_upload_with_metadata_run_id(
887950
# Metadata keys should match those in the mock_tdms_file properties
888951
keys = [md["key"]["name"] for md in patch_data["run"]["metadata"]]
889952
assert set(keys) == set(mock_waveform_tdms_file.properties.keys())
953+
954+
955+
def test_waveform_tdms_with_scaling_upload_success(
956+
mocker: MockFixture, waveform_tdms_file_with_scaling: MockTdmsFile
957+
):
958+
mock_path_is_file = mocker.patch("sift_py.data_import.tdms.Path.is_file")
959+
mock_path_is_file.return_value = True
960+
961+
mock_path_getsize = mocker.patch("sift_py.data_import.csv.os.path.getsize")
962+
mock_path_getsize.return_value = 10
963+
964+
mock_requests_post = mocker.patch("sift_py.rest.requests.Session.post")
965+
mock_requests_post.return_value = MockResponse()
966+
967+
def mock_tdms_file_constructor(path):
968+
"""The first call should always return the mocked object since
969+
it is mocking a call to open the orignal tdms file.
970+
971+
The second call should return a real TdmsFile since the unit
972+
test will actually create one with filtered channels.
973+
"""
974+
if path == "some_tdms.tdms":
975+
return waveform_tdms_file_with_scaling
976+
else:
977+
return TdmsFile(path)
978+
979+
mocker.patch("sift_py.data_import.tdms.TdmsFile", mock_tdms_file_constructor)
980+
981+
# Create a mock file so we can cpature the data that's written
982+
class MockNamedTemporaryFile:
983+
def __init__(self, **kwargs):
984+
self.data = ""
985+
self.name = "filename.csv"
986+
987+
def write(self, data: str):
988+
self.data += data
989+
return len(data)
990+
991+
def close(self):
992+
pass
993+
994+
def __enter__(self):
995+
return self
996+
997+
def __exit__(self, exc_type, exc_val, exc_tb):
998+
pass
999+
1000+
mock_temp_files = []
1001+
1002+
def mock_temp_file_constructor(**kwargs):
1003+
mf = MockNamedTemporaryFile(**kwargs)
1004+
mock_temp_files.append(mf)
1005+
return mf
1006+
1007+
mocker.patch("sift_py.data_import.tdms.NamedTemporaryFile", mock_temp_file_constructor)
1008+
1009+
svc = TdmsUploadService(rest_config)
1010+
1011+
def get_csv_config(mock, n):
1012+
"""Return the CSV config that was created and uploaded under the hood."""
1013+
return json.loads(mock_requests_post.call_args_list[n].kwargs["data"])["csv_config"]
1014+
1015+
# Test without grouping
1016+
svc.upload("some_tdms.tdms", "asset_name")
1017+
config = get_csv_config(mock_requests_post, 0)
1018+
expected_config: Dict[str, Any] = {
1019+
"asset_name": "asset_name",
1020+
"run_name": "",
1021+
"run_id": "",
1022+
"first_data_row": 2,
1023+
"time_column": {
1024+
"format": "TIME_FORMAT_ABSOLUTE_DATETIME",
1025+
"column_number": 1,
1026+
"relative_start_time": None,
1027+
},
1028+
"data_columns": {},
1029+
}
1030+
for i in range(3):
1031+
expected_config["data_columns"][str(2 + i)] = {
1032+
"name": f"Test/channel_{i}",
1033+
"data_type": "CHANNEL_DATA_TYPE_INT_32" if i == 0 else "CHANNEL_DATA_TYPE_DOUBLE",
1034+
"units": "",
1035+
"description": "",
1036+
"enum_types": [],
1037+
"bit_field_elements": [],
1038+
}
1039+
assert config == expected_config
1040+
1041+
# Create a pandas DataFrame with the expected resulting CSV data
1042+
# Values should be scaled correctly.
1043+
df = pd.DataFrame(
1044+
{
1045+
"": [
1046+
np.datetime64("2025-10-19T00:00:00.000000"),
1047+
np.datetime64("2025-10-19T00:00:00.100000"),
1048+
np.datetime64("2025-10-19T00:00:00.200000"),
1049+
],
1050+
"/'Group 0'/'Test/channel_0'": [1, 2, 3],
1051+
"/'Group 0'/'Test/channel_1'": [11.5, 13.0, 14.5],
1052+
"/'Group 0'/'Test/channel_2'": [11.5, 13.0, 14.5],
1053+
}
1054+
)
1055+
1056+
csv_buffer = io.StringIO()
1057+
df.to_csv(csv_buffer, index=False)
1058+
csv_content = csv_buffer.getvalue()
1059+
1060+
assert mock_temp_files[0].data == csv_content

python/lib/sift_py/data_import/tdms.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,21 @@
3636
from sift_py.ingestion.channel import ChannelDataType
3737
from sift_py.rest import SiftRestConfig
3838

39-
TDMS_TO_SIFT_TYPES = {
40-
types.Boolean: ChannelDataType.BOOL,
41-
types.Int8: ChannelDataType.INT_32,
42-
types.Int16: ChannelDataType.INT_32,
43-
types.Int32: ChannelDataType.INT_32,
44-
types.Int64: ChannelDataType.INT_64,
45-
types.Uint8: ChannelDataType.UINT_32,
46-
types.Uint16: ChannelDataType.UINT_32,
47-
types.Uint32: ChannelDataType.UINT_32,
48-
types.Uint64: ChannelDataType.UINT_64,
49-
types.SingleFloat: ChannelDataType.FLOAT,
50-
types.DoubleFloat: ChannelDataType.DOUBLE,
51-
types.String: ChannelDataType.STRING,
39+
# Mapping from numpy data types to Sift ChannelDataType
40+
NUMPY_TO_SIFT_TYPES = {
41+
np.bool_: ChannelDataType.BOOL,
42+
np.int8: ChannelDataType.INT_32,
43+
np.int16: ChannelDataType.INT_32,
44+
np.int32: ChannelDataType.INT_32,
45+
np.int64: ChannelDataType.INT_64,
46+
np.uint8: ChannelDataType.UINT_32,
47+
np.uint16: ChannelDataType.UINT_32,
48+
np.uint32: ChannelDataType.UINT_32,
49+
np.uint64: ChannelDataType.UINT_64,
50+
np.float32: ChannelDataType.FLOAT,
51+
np.float64: ChannelDataType.DOUBLE,
52+
np.str_: ChannelDataType.STRING,
53+
np.object_: ChannelDataType.STRING,
5254
}
5355

5456

@@ -65,7 +67,9 @@ class TdmsTimeFormat(Enum):
6567
# Implements the same interface as TdmsChannel. Allows us to create
6668
# TdmsChannel like objects without having to save and read the channels to
6769
# a file.
68-
_TdmsChannel = namedtuple("_TdmsChannel", ["group_name", "name", "data_type", "data", "properties"])
70+
_TdmsChannel = namedtuple(
71+
"_TdmsChannel", ["group_name", "name", "data_type", "data", "properties", "dtype"]
72+
)
6973

7074

7175
CHARACTER_REPLACEMENTS = {
@@ -282,7 +286,7 @@ def contains_timing(channel: TdmsChannel) -> bool:
282286
new_channel = ChannelObject(
283287
group=sanitize_string(channel.group_name),
284288
channel=sanitize_string(channel.name),
285-
data=channel.data,
289+
data=channel.raw_data,
286290
properties=channel.properties,
287291
)
288292
valid_channels.append(new_channel)
@@ -407,6 +411,7 @@ def get_time_channels(group: TdmsGroup) -> List[TdmsChannel]:
407411
group_name=updated_group_name,
408412
name=updated_channel_name,
409413
data_type=channel.data_type,
414+
dtype=channel.dtype,
410415
data=data,
411416
properties=channel.properties,
412417
)
@@ -485,12 +490,12 @@ def _create_csv_config(
485490
first_data_column = 2
486491
for i, channel in enumerate(channels):
487492
try:
488-
data_type = TDMS_TO_SIFT_TYPES[channel.data_type].as_human_str(api_format=True)
493+
data_type = NUMPY_TO_SIFT_TYPES[channel.dtype.type].as_human_str(api_format=True)
489494
except KeyError:
490495
data_type = None
491496

492497
if data_type is None:
493-
raise Exception(f"{channel.name} data type not supported: {channel.data_type}")
498+
raise Exception(f"{channel.name} data type not supported: {channel.dtype}")
494499

495500
channel_config = DataColumn(
496501
name=_channel_fqn(name=channel.name, component=channel.group_name)

0 commit comments

Comments
 (0)