@@ -278,3 +278,63 @@ def test_cat_bool(labels):
278278 assert selected .stored_labels [1 ]["dof" ] == [f"s{ i } " for i in range (10 )]
279279 if isinstance (labels , dict ):
280280 assert selected .stored_labels [0 ]["dof" ] == ["a" , "b" ]
281+
282+
283+ def test_getitem_int ():
284+ data = torch .rand (20 , 3 )
285+ labels = {1 : {"name" : 1 , "dof" : ["x" , "y" , "z" ]}}
286+ lt = LabelTensor (data , labels )
287+ new = lt [0 , 0 ]
288+ assert new .ndim == 1
289+ assert new .shape [0 ] == 1
290+ assert torch .all (torch .isclose (data [0 , 0 ], new ))
291+
292+ data = torch .rand (20 , 3 , 2 )
293+ labels = {
294+ 1 : {"name" : 1 , "dof" : ["x" , "y" , "z" ]},
295+ 2 : {"name" : 2 , "dof" : ["a" , "b" ]},
296+ }
297+ lt = LabelTensor (data , labels )
298+ new = lt [0 , 0 , 0 ]
299+ assert new .ndim == 2
300+ assert new .shape [0 ] == 1
301+ assert new .shape [1 ] == 1
302+ assert torch .all (torch .isclose (data [0 , 0 , 0 ], new ))
303+ assert new .stored_labels [0 ]["dof" ] == ["x" ]
304+ assert new .stored_labels [1 ]["dof" ] == ["a" ]
305+
306+ new = lt [0 , 0 , :]
307+ assert new .ndim == 2
308+ assert new .shape [0 ] == 1
309+ assert new .shape [1 ] == 2
310+ assert torch .all (torch .isclose (data [0 , 0 , :], new ))
311+ assert new .stored_labels [0 ]["dof" ] == ["x" ]
312+ assert new .stored_labels [1 ]["dof" ] == ["a" , "b" ]
313+
314+ new = lt [0 , :, 1 ]
315+ assert new .ndim == 2
316+ assert new .shape [0 ] == 3
317+ assert new .shape [1 ] == 1
318+ assert torch .all (torch .isclose (data [0 , :, 1 ], new .squeeze ()))
319+ assert new .stored_labels [0 ]["dof" ] == ["x" , "y" , "z" ]
320+ assert new .stored_labels [1 ]["dof" ] == ["b" ]
321+
322+ labels .pop (2 )
323+ lt = LabelTensor (data , labels )
324+ new = lt [0 , 0 , 0 ]
325+ assert new .ndim == 1
326+ assert new .shape [0 ] == 1
327+ assert new .stored_labels [0 ]["dof" ] == ["x" ]
328+
329+ new = lt [:, 0 , 0 ]
330+ assert new .ndim == 2
331+ assert new .shape [0 ] == 20
332+ assert new .shape [1 ] == 1
333+ assert new .stored_labels [1 ]["dof" ] == ["x" ]
334+
335+ new = lt [:, 0 , :]
336+ assert new .ndim == 3
337+ assert new .shape [0 ] == 20
338+ assert new .shape [1 ] == 1
339+ assert new .shape [2 ] == 2
340+ assert new .stored_labels [1 ]["dof" ] == ["x" ]
0 commit comments