1919from executorch .backends .transforms .postpone_permute_below_squeeze_view import (
2020 PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView ,
2121)
22- from executorch .backends .transforms .replace_nop_transpose_or_permute_with_view import (
23- ReplaceNopTransposeOrPermuteWithViewPass ,
24- )
2522from executorch .backends .transforms .remove_permutes_around_elementwise_ops import (
2623 RemovePermutesAroundElementwiseOps ,
2724)
25+ from executorch .backends .transforms .replace_nop_transpose_or_permute_with_view import (
26+ ReplaceNopTransposeOrPermuteWithViewPass ,
27+ )
2828from executorch .exir .dialects ._ops import ops as exir_ops
2929from executorch .exir .pass_base import PassResult
3030from torch .utils import _pytree as pytree
@@ -122,15 +122,12 @@ def test_cascaded_permutes_multiple_users(self) -> None:
122122 permute1 = builder .call_operator (
123123 op = exir_ops .edge .aten .permute_copy .default , args = (x , [0 , 2 , 3 , 1 ])
124124 )
125- # permute2 reverses permute1 => identity
126125 permute2 = builder .call_operator (
127126 op = exir_ops .edge .aten .permute_copy .default , args = (permute1 , [0 , 3 , 1 , 2 ])
128127 )
129- # permute3: different permutation
130128 permute3 = builder .call_operator (
131129 op = exir_ops .edge .aten .permute_copy .default , args = (permute1 , [0 , 2 , 1 , 3 ])
132130 )
133- # permute4 -> permute5: chained
134131 permute4 = builder .call_operator (
135132 op = exir_ops .edge .aten .permute_copy .default , args = (permute1 , [3 , 2 , 0 , 1 ])
136133 )
@@ -151,6 +148,38 @@ def test_cascaded_permutes_multiple_users(self) -> None:
151148 "FuseCascadedTransposeOrPermuteOps" ,
152149 )
153150
151+ def test_permute_view_permute_fuse (self ) -> None :
152+ """permute_3D([0,2,1]) → view(unsqueeze) → permute_4D([0,2,3,1]) should
153+ be replaced with a single view_copy (permutations cancel out)."""
154+ builder = GraphBuilder ()
155+ x_data = torch .randn (1 , 40 , 18 )
156+ x = builder .placeholder ("x" , x_data )
157+ p1 = builder .call_operator (
158+ op = exir_ops .edge .aten .permute_copy .default , args = (x , [0 , 2 , 1 ])
159+ )
160+ v = builder .call_operator (
161+ op = exir_ops .edge .aten .view_copy .default , args = (p1 , [1 , 18 , 1 , 40 ])
162+ )
163+ p2 = builder .call_operator (
164+ op = exir_ops .edge .aten .permute_copy .default , args = (v , [0 , 2 , 3 , 1 ])
165+ )
166+ builder .output ([p2 ])
167+ original = builder .get_graph_module ()
168+ gm_before = copy .deepcopy (original )
169+
170+ p = FuseCascadedTransposeOrPermuteOps ()
171+ result = cast (PassResult , p (original ))
172+ self .assertTrue (result .modified )
173+ gm = result .graph_module
174+
175+ self .assertEqual (count_node (gm , exir_ops .edge .aten .permute_copy .default ), 0 )
176+ self .assertGreaterEqual (
177+ count_node (gm , exir_ops .edge .aten .view_copy .default ), 1
178+ )
179+ validate_numerics (
180+ gm_before , gm , [x_data ], "FuseCascadedAcrossView" ,
181+ )
182+
154183
155184# ──────────────────────────────────────────────────────────────────────
156185# Tests for FuseCascadedViewOps
@@ -250,7 +279,6 @@ def test_permute3_view4_chains(self) -> None:
250279
251280 self .assertEqual (count_node (gm , exir_ops .edge .aten .view_copy .default ), 2 )
252281 self .assertEqual (count_node (gm , exir_ops .edge .aten .permute_copy .default ), 2 )
253- # Verify order: views before permutes
254282 targets = get_compute_nodes (gm )
255283 view_indices = [
256284 i
@@ -350,7 +378,6 @@ def test_negative_not_squeeze_like(self) -> None:
350378 count_node (result .graph_module , exir_ops .edge .aten .permute_copy .default ),
351379 2 ,
352380 )
353- # Order unchanged: view, permute, view, permute
354381 targets = get_compute_nodes (result .graph_module )
355382 self .assertEqual (targets [0 ], exir_ops .edge .aten .view_copy .default )
356383 self .assertEqual (targets [1 ], exir_ops .edge .aten .permute_copy .default )
@@ -383,6 +410,67 @@ def test_replace_nop_transpose_with_view_float(self) -> None:
383410 gm_before , gm_after , [x ], "ReplaceNopTransposeOrPermuteWithViewPass"
384411 )
385412
413+ def test_replace_nop_transpose_with_view_int (self ) -> None :
414+ x = torch .randint (low = 0 , high = 100 , size = (2 , 1 , 5 ), dtype = torch .int64 )
415+ gm = single_op_builder (
416+ placeholders = (x ,),
417+ op = exir_ops .edge .aten .transpose_copy .int ,
418+ args = (x , 1 , 0 ),
419+ )
420+ gm_before = copy .deepcopy (gm )
421+
422+ p = ReplaceNopTransposeOrPermuteWithViewPass ()
423+ result = cast (PassResult , p (gm ))
424+ self .assertTrue (result .modified )
425+ gm_after = result .graph_module
426+ self .assertEqual (count_node (gm_after , exir_ops .edge .aten .transpose_copy .int ), 0 )
427+ self .assertEqual (count_node (gm_after , exir_ops .edge .aten .view_copy .default ), 1 )
428+ validate_numerics (
429+ gm_before , gm_after , [x ], "ReplaceNopTransposeOrPermuteWithViewPass"
430+ )
431+
432+ def test_replace_nop_permute_5d (self ) -> None :
433+ x = torch .randn (3 , 1 , 3 , 1 , 4 )
434+ gm = single_op_builder (
435+ placeholders = (x ,),
436+ op = exir_ops .edge .aten .permute_copy .default ,
437+ args = (x , [0 , 2 , 4 , 1 , 3 ]),
438+ )
439+ gm_before = copy .deepcopy (gm )
440+
441+ p = ReplaceNopTransposeOrPermuteWithViewPass ()
442+ result = cast (PassResult , p (gm ))
443+ self .assertTrue (result .modified )
444+ gm_after = result .graph_module
445+ self .assertEqual (
446+ count_node (gm_after , exir_ops .edge .aten .permute_copy .default ), 0
447+ )
448+ self .assertEqual (count_node (gm_after , exir_ops .edge .aten .view_copy .default ), 1 )
449+ validate_numerics (
450+ gm_before , gm_after , [x ], "ReplaceNopTransposeOrPermuteWithViewPass"
451+ )
452+
453+ def test_replace_nop_permute_3d (self ) -> None :
454+ x = torch .randn (1 , 3 , 4 )
455+ gm = single_op_builder (
456+ placeholders = (x ,),
457+ op = exir_ops .edge .aten .permute_copy .default ,
458+ args = (x , [1 , 2 , 0 ]),
459+ )
460+ gm_before = copy .deepcopy (gm )
461+
462+ p = ReplaceNopTransposeOrPermuteWithViewPass ()
463+ result = cast (PassResult , p (gm ))
464+ self .assertTrue (result .modified )
465+ gm_after = result .graph_module
466+ self .assertEqual (
467+ count_node (gm_after , exir_ops .edge .aten .permute_copy .default ), 0
468+ )
469+ self .assertEqual (count_node (gm_after , exir_ops .edge .aten .view_copy .default ), 1 )
470+ validate_numerics (
471+ gm_before , gm_after , [x ], "ReplaceNopTransposeOrPermuteWithViewPass"
472+ )
473+
386474
387475# ──────────────────────────────────────────────────────────────────────
388476# Tests for RemovePermutesAroundElementwiseOps cross-view handling
@@ -458,64 +546,3 @@ def test_4d_permute_squeeze_clamp_3d_permute(self) -> None:
458546 gm_before , result .graph_module , [x_data ],
459547 "4D_permute_squeeze_clamp_3D_permute" ,
460548 )
461-
462- def test_replace_nop_transpose_with_view_int (self ) -> None :
463- x = torch .randint (low = 0 , high = 100 , size = (2 , 1 , 5 ), dtype = torch .int64 )
464- gm = single_op_builder (
465- placeholders = (x ,),
466- op = exir_ops .edge .aten .transpose_copy .int ,
467- args = (x , 1 , 0 ),
468- )
469- gm_before = copy .deepcopy (gm )
470-
471- p = ReplaceNopTransposeOrPermuteWithViewPass ()
472- result = cast (PassResult , p (gm ))
473- self .assertTrue (result .modified )
474- gm_after = result .graph_module
475- self .assertEqual (count_node (gm_after , exir_ops .edge .aten .transpose_copy .int ), 0 )
476- self .assertEqual (count_node (gm_after , exir_ops .edge .aten .view_copy .default ), 1 )
477- validate_numerics (
478- gm_before , gm_after , [x ], "ReplaceNopTransposeOrPermuteWithViewPass"
479- )
480-
481- def test_replace_nop_permute_5d (self ) -> None :
482- x = torch .randn (3 , 1 , 3 , 1 , 4 )
483- gm = single_op_builder (
484- placeholders = (x ,),
485- op = exir_ops .edge .aten .permute_copy .default ,
486- args = (x , [0 , 2 , 4 , 1 , 3 ]),
487- )
488- gm_before = copy .deepcopy (gm )
489-
490- p = ReplaceNopTransposeOrPermuteWithViewPass ()
491- result = cast (PassResult , p (gm ))
492- self .assertTrue (result .modified )
493- gm_after = result .graph_module
494- self .assertEqual (
495- count_node (gm_after , exir_ops .edge .aten .permute_copy .default ), 0
496- )
497- self .assertEqual (count_node (gm_after , exir_ops .edge .aten .view_copy .default ), 1 )
498- validate_numerics (
499- gm_before , gm_after , [x ], "ReplaceNopTransposeOrPermuteWithViewPass"
500- )
501-
502- def test_replace_nop_permute_3d (self ) -> None :
503- x = torch .randn (1 , 3 , 4 )
504- gm = single_op_builder (
505- placeholders = (x ,),
506- op = exir_ops .edge .aten .permute_copy .default ,
507- args = (x , [1 , 2 , 0 ]),
508- )
509- gm_before = copy .deepcopy (gm )
510-
511- p = ReplaceNopTransposeOrPermuteWithViewPass ()
512- result = cast (PassResult , p (gm ))
513- self .assertTrue (result .modified )
514- gm_after = result .graph_module
515- self .assertEqual (
516- count_node (gm_after , exir_ops .edge .aten .permute_copy .default ), 0
517- )
518- self .assertEqual (count_node (gm_after , exir_ops .edge .aten .view_copy .default ), 1 )
519- validate_numerics (
520- gm_before , gm_after , [x ], "ReplaceNopTransposeOrPermuteWithViewPass"
521- )
0 commit comments