1919 Conditional , DummyEq , Expression , FindNodes , FindSymbols , Iteration ,
2020 ParallelIteration , retrieve_iteration_tree
2121)
22- from devito .passes .clusters .aliases import collect
22+ from devito .passes .clusters .aliases import AliasKey , collect
2323from devito .passes .clusters .factorization import collect_nested
2424from devito .passes .iet .parpragma import VExpanded
2525from devito .symbolics import ( # noqa
@@ -423,8 +423,9 @@ def test_collection(self, exprs, expected):
423423
424424 extracted = {i .rhs : i .lhs for i in exprs }
425425 ispace = exprs [0 ].ispace
426+ meta = AliasKey (ispace , None , None , None , None )
426427
427- aliases = collect (extracted , ispace , False )
428+ aliases = collect (extracted , meta , False )
428429 aliases .filter (lambda a : a .score > 0 )
429430
430431 assert len (aliases ) == len (expected )
@@ -2553,15 +2554,15 @@ def test_invariants_with_conditional(self):
25532554
25542555 op = Operator (eqn , opt = 'advanced' )
25552556
2556- assert_structure (op , ['t' , 't,fd' , 't,fd ,x,y' ], 't,fd,x,y' )
2557+ assert_structure (op , ['t' , 't,fd,x,y' ], 't,fd,x,y' )
25572558 # Make sure it compiles
25582559 _ = op .cfunction
25592560
25602561 # Check hoisting for time invariant
25612562 eqn = Eq (u , u - (cos (time_sub * factor * f ) * sin (g ) * uf ))
25622563
25632564 op = Operator (eqn , opt = 'advanced' )
2564- assert_structure (op , ['x,y' , 't' , 't,fd' , 't,fd ,x,y' ], 'x,y,t,fd,x,y' )
2565+ assert_structure (op , ['x,y' , 't' , 't,fd,x,y' ], 'x,y,t,fd,x,y' )
25652566 # Make sure it compiles
25662567 _ = op .cfunction
25672568
@@ -2705,10 +2706,9 @@ def test_split_cond(self):
27052706
27062707 cond = FindNodes (Conditional ).visit (op )
27072708 assert len (cond ) == 3
2708- # Each guard should have its own alias for cos(time)
2709- assert 'float r0 = cos(time);' in str (body0 (op ))
2709+ # No aliases in this case due to guards
27102710 scalars = [i for i in FindSymbols ().visit (op ) if isinstance (i , Temp )]
2711- assert len (scalars ) == 2
2711+ assert len (scalars ) == 0
27122712
27132713 def test_split_cond_multi_alias (self ):
27142714 grid = Grid ((11 , 11 ))
@@ -2728,11 +2728,9 @@ def test_split_cond_multi_alias(self):
27282728
27292729 cond = FindNodes (Conditional ).visit (op )
27302730 assert len (cond ) == 3
2731- # Each guard should have its own aliases for cos(time) and sin(time)
2732- assert 'const float r0 = sin(time) + cos(time)' in str (body0 (op ))
2733- assert 'const float r1 = cos(time);' in str (body0 (op ))
2731+ # No aliases in this case due to guards
27342732 scalars = [i for i in FindSymbols ().visit (op ) if isinstance (i , Temp )]
2735- assert len (scalars ) == 3
2733+ assert len (scalars ) == 0
27362734
27372735 def test_multi_cond_no_split (self ):
27382736 grid = Grid ((11 , 11 ))
@@ -2758,7 +2756,7 @@ def test_multi_cond_no_split(self):
27582756 )
27592757
27602758 scalars = [i for i in FindSymbols ().visit (op ) if isinstance (i , Temp )]
2761- assert len (scalars ) == 3
2759+ assert len (scalars ) == 0
27622760
27632761 def test_alias_with_conditional (self ):
27642762 grid = Grid ((11 , 11 ))
@@ -2779,9 +2777,9 @@ def test_alias_with_conditional(self):
27792777 cond = FindNodes (Conditional ).visit (op )
27802778 assert len (cond ) == 3
27812779
2782- # Each guard should have its own alias for cos(time/ctf)
2780+ # No aliases in this case due to guards
27832781 scalars = [i for i in FindSymbols ().visit (op ) if isinstance (i , Temp )]
2784- assert len (scalars ) == 2
2782+ assert len (scalars ) == 0
27852783
27862784 def test_scalar_alias_interp (self ):
27872785 grid = Grid (shape = (11 , 11 ))
@@ -2825,9 +2823,9 @@ def test_scalar_with_cond_access(self):
28252823 cond = FindNodes (Conditional ).visit (op )
28262824 assert len (cond ) == 3
28272825
2828- # # Each guard should have its own alias for cos/sin(f1[time-2])
2826+ # The guards prevent some aliases from being hoisted out
28292827 scalars = [i for i in FindSymbols ().visit (op ) if isinstance (i , Temp )]
2830- assert len (scalars ) == 3
2828+ assert len (scalars ) == 0
28312829
28322830 assert_structure (
28332831 op ,
@@ -2855,21 +2853,19 @@ def test_scalar_with_cond_tinvariant(self):
28552853
28562854 cond = FindNodes (Conditional ).visit (op )
28572855 assert len (cond ) == 1
2858- # One for each 1/dt 1/dt**2
2856+ # One for 1/dt, while 1/dt**2 ain't hoisted out due to the guard
28592857 scalars = [i for i in FindSymbols ().visit (op ) if isinstance (i , Temp )]
2860- assert len (scalars ) == 2
2858+ assert len (scalars ) == 1
28612859
28622860 assert_structure (
28632861 op ,
28642862 ['t,x,y' , 't' , 't,x,y' ],
28652863 'txyxy'
28662864 )
28672865
2868- # Both aliases should be hoisted outside the time loop
2866+ # The 1/dt alias should be hoisted outside the time loop
28692867 assert str (body0 (op ).body [0 ]) == 'const float r0 = 1.0F/dt;'
28702868 assert not body0 (op ).body [0 ].ispace
2871- assert str (body0 (op ).body [1 ]) == 'const float r1 = 1.0F/(dt*dt);'
2872- assert not body0 (op ).body [1 ].ispace
28732869
28742870
28752871class TestIsoAcoustic :
0 commit comments