@@ -315,29 +315,38 @@ def test_fix_ungrouped_dims(
315315@mark .parametrize (
316316 [
317317 "physical_shape" ,
318+ "strides" ,
318319 "pdim" ,
319320 "new_pdim_shape" ,
320321 "expected_physical_shape" ,
321- "expected_new_encoding " ,
322+ "expected_strides " ,
322323 ],
323324 [
324- ([4 ], 0 , [4 ], [4 ], [[0 ]]), # trivial
325- ([4 ], 0 , [2 , 2 ], [2 , 2 ], [[0 , 1 ]]),
326- ([3 , 4 , 5 ], 1 , [2 , 1 , 1 , 2 ], [3 , 2 , 1 , 1 , 2 , 5 ], [[0 ], [1 , 2 , 3 , 4 ], [5 ]]),
325+ ([4 ], tensor ([[1 ], [2 ]]), 0 , [4 ], [4 ], tensor ([[1 ], [2 ]])), # trivial
326+ ([4 ], tensor ([[1 ], [2 ]]), 0 , [2 , 2 ], [2 , 2 ], tensor ([[2 , 1 ], [4 , 2 ]])),
327+ (
328+ [3 , 4 , 5 ],
329+ tensor ([[1 , 2 , 0 ], [1 , 0 , 1 ], [0 , 1 , 1 ]]),
330+ 1 ,
331+ [2 , 1 , 1 , 2 ],
332+ [3 , 2 , 1 , 1 , 2 , 5 ],
333+ tensor ([[1 , 4 , 4 , 4 , 2 , 0 ], [1 , 0 , 0 , 0 , 0 , 1 ], [0 , 2 , 2 , 2 , 1 , 1 ]]),
334+ ),
327335 ],
328336)
329337def test_unsquash_pdim (
330338 physical_shape : list [int ],
339+ strides : Tensor ,
331340 pdim : int ,
332341 new_pdim_shape : list [int ],
333342 expected_physical_shape : list [int ],
334- expected_new_encoding : list [ list [ int ]] ,
343+ expected_strides : Tensor ,
335344):
336345 physical = randn_ (physical_shape )
337- new_physical , new_encoding = unsquash_pdim (physical , pdim , new_pdim_shape )
346+ new_physical , new_strides = unsquash_pdim (physical , strides , pdim , new_pdim_shape )
338347
339348 assert list (new_physical .shape ) == expected_physical_shape
340- assert new_encoding == expected_new_encoding
349+ assert torch . equal ( new_strides , expected_strides )
341350
342351
343352@mark .parametrize (
0 commit comments