Skip to content

Commit 1b0dd19

Browse files
committed
Add tests
1 parent 0606ff3 commit 1b0dd19

1 file changed

Lines changed: 60 additions & 0 deletions

File tree

tests/test_label_tensor/test_label_tensor.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,63 @@ def test_cat_bool(labels):
278278
assert selected.stored_labels[1]["dof"] == [f"s{i}" for i in range(10)]
279279
if isinstance(labels, dict):
280280
assert selected.stored_labels[0]["dof"] == ["a", "b"]
281+
282+
283+
def test_getitem_int():
284+
data = torch.rand(20, 3)
285+
labels = {1: {"name": 1, "dof": ["x", "y", "z"]}}
286+
lt = LabelTensor(data, labels)
287+
new = lt[0, 0]
288+
assert new.ndim == 1
289+
assert new.shape[0] == 1
290+
assert torch.all(torch.isclose(data[0, 0], new))
291+
292+
data = torch.rand(20, 3, 2)
293+
labels = {
294+
1: {"name": 1, "dof": ["x", "y", "z"]},
295+
2: {"name": 2, "dof": ["a", "b"]},
296+
}
297+
lt = LabelTensor(data, labels)
298+
new = lt[0, 0, 0]
299+
assert new.ndim == 2
300+
assert new.shape[0] == 1
301+
assert new.shape[1] == 1
302+
assert torch.all(torch.isclose(data[0, 0, 0], new))
303+
assert new.stored_labels[0]["dof"] == ["x"]
304+
assert new.stored_labels[1]["dof"] == ["a"]
305+
306+
new = lt[0, 0, :]
307+
assert new.ndim == 2
308+
assert new.shape[0] == 1
309+
assert new.shape[1] == 2
310+
assert torch.all(torch.isclose(data[0, 0, :], new))
311+
assert new.stored_labels[0]["dof"] == ["x"]
312+
assert new.stored_labels[1]["dof"] == ["a", "b"]
313+
314+
new = lt[0, :, 1]
315+
assert new.ndim == 2
316+
assert new.shape[0] == 3
317+
assert new.shape[1] == 1
318+
assert torch.all(torch.isclose(data[0, :, 1], new.squeeze()))
319+
assert new.stored_labels[0]["dof"] == ["x", "y", "z"]
320+
assert new.stored_labels[1]["dof"] == ["b"]
321+
322+
labels.pop(2)
323+
lt = LabelTensor(data, labels)
324+
new = lt[0, 0, 0]
325+
assert new.ndim == 1
326+
assert new.shape[0] == 1
327+
assert new.stored_labels[0]["dof"] == ["x"]
328+
329+
new = lt[:, 0, 0]
330+
assert new.ndim == 2
331+
assert new.shape[0] == 20
332+
assert new.shape[1] == 1
333+
assert new.stored_labels[1]["dof"] == ["x"]
334+
335+
new = lt[:, 0, :]
336+
assert new.ndim == 3
337+
assert new.shape[0] == 20
338+
assert new.shape[1] == 1
339+
assert new.shape[2] == 2
340+
assert new.stored_labels[1]["dof"] == ["x"]

0 commit comments

Comments
 (0)