@@ -30,32 +30,18 @@ class Map(object):
3030
3131 __slots__ = ("values" , "offset" , "interior_horizontal" ,
3232 "variable" , "unroll" , "layer_bounds" ,
33- "prefetch" , "permutation " )
33+ "prefetch" , "_pmap_count " )
3434
3535 def __init__ (self , map_ , interior_horizontal , layer_bounds ,
36- values = None , offset = None , unroll = False ):
36+ offset = None , unroll = False ):
3737 self .variable = map_ .iterset ._extruded and not map_ .iterset .constant_layers
3838 self .unroll = unroll
3939 self .layer_bounds = layer_bounds
4040 self .interior_horizontal = interior_horizontal
4141 self .prefetch = {}
42- if values is not None :
43- raise RuntimeError
44- self .values = values
45- if map_ .offset is not None :
46- assert offset is not None
47- self .offset = offset
48- return
49-
5042 offset = map_ .offset
5143 shape = (None , ) + map_ .shape [1 :]
5244 values = Argument (shape , dtype = map_ .dtype , pfx = "map" )
53- if isinstance (map_ , PermutedMap ):
54- self .permutation = NamedLiteral (map_ .permutation , parent = values , suffix = "permutation" )
55- if offset is not None :
56- offset = offset [map_ .permutation ]
57- else :
58- self .permutation = None
5945 if offset is not None :
6046 if len (set (map_ .offset )) == 1 :
6147 offset = Literal (offset [0 ], casting = True )
@@ -64,6 +50,7 @@ def __init__(self, map_, interior_horizontal, layer_bounds,
6450
6551 self .values = values
6652 self .offset = offset
53+ self ._pmap_count = itertools .count ()
6754
6855 @property
6956 def shape (self ):
@@ -73,7 +60,7 @@ def shape(self):
7360 def dtype (self ):
7461 return self .values .dtype
7562
76- def indexed (self , multiindex , layer = None ):
63+ def indexed (self , multiindex , layer = None , permute = lambda x : x ):
7764 n , i , f = multiindex
7865 if layer is not None and self .offset is not None :
7966 # For extruded mesh, prefetch the indirections for each map, so that they don't
@@ -82,10 +69,7 @@ def indexed(self, multiindex, layer=None):
8269 base_key = None
8370 if base_key not in self .prefetch :
8471 j = Index ()
85- if self .permutation is None :
86- base = Indexed (self .values , (n , j ))
87- else :
88- base = Indexed (self .values , (n , Indexed (self .permutation , (j ,))))
72+ base = Indexed (self .values , (n , permute (j )))
8973 self .prefetch [base_key ] = Materialise (PackInst (), base , MultiIndex (j ))
9074
9175 base = self .prefetch [base_key ]
@@ -112,26 +96,58 @@ def indexed(self, multiindex, layer=None):
11296 return Indexed (self .prefetch [key ], (f , i )), (f , i )
11397 else :
11498 assert f .extent == 1 or f .extent is None
115- if self .permutation is None :
116- base = Indexed (self .values , (n , i ))
117- else :
118- base = Indexed (self .values , (n , Indexed (self .permutation , (i ,))))
99+ base = Indexed (self .values , (n , permute (i )))
119100 return base , (f , i )
120101
121- def indexed_vector (self , n , shape , layer = None ):
102+ def indexed_vector (self , n , shape , layer = None , permute = lambda x : x ):
122103 shape = self .shape [1 :] + shape
123104 if self .interior_horizontal :
124105 shape = (2 , ) + shape
125106 else :
126107 shape = (1 , ) + shape
127108 f , i , j = (Index (e ) for e in shape )
128- base , (f , i ) = self .indexed ((n , i , f ), layer = layer )
109+ base , (f , i ) = self .indexed ((n , i , f ), layer = layer , permute = permute )
129110 init = Sum (Product (base , Literal (numpy .int32 (j .extent ))), j )
130111 pack = Materialise (PackInst (), init , MultiIndex (f , i , j ))
131112 multiindex = tuple (Index (e ) for e in pack .shape )
132113 return Indexed (pack , multiindex ), multiindex
133114
134115
116+ class PMap (Map ):
117+ __slots__ = ("permutation" ,)
118+
119+ def __init__ (self , map_ , permutation ):
120+ # Copy over properties
121+ self .variable = map_ .variable
122+ self .unroll = map_ .unroll
123+ self .layer_bounds = map_ .layer_bounds
124+ self .interior_horizontal = map_ .interior_horizontal
125+ self .prefetch = {}
126+ self .values = map_ .values
127+ self .offset = map_ .offset
128+ offset = map_ .offset
129+ # TODO: this is a hack, rep2loopy should be in charge of
130+ # generating all names!
131+ count = next (map_ ._pmap_count )
132+ if offset is not None :
133+ if offset .shape :
134+ # Have a named literal
135+ offset = offset .value [permutation ]
136+ offset = NamedLiteral (offset , parent = self .values , suffix = f"permutation{ count } _offset" )
137+ else :
138+ offset = map_ .offset
139+ self .offset = offset
140+ self .permutation = NamedLiteral (permutation , parent = self .values , suffix = f"permutation{ count } " )
141+
142+ def indexed (self , multiindex , layer = None ):
143+ permute = lambda x : Indexed (self .permutation , (x ,))
144+ return super ().indexed (multiindex , layer = layer , permute = permute )
145+
146+ def indexed_vector (self , n , shape , layer = None ):
147+ permute = lambda x : Indexed (self .permutation , (x ,))
148+ return super ().indexed_vector (n , shape , layer = layer , permute = permute )
149+
150+
135151class Pack (metaclass = ABCMeta ):
136152
137153 def pick_loop_indices (self , loop_index , layer_index = None , entity_index = None ):
@@ -818,9 +834,13 @@ def map_(self, map_, unroll=False):
818834 try :
819835 return self .maps [key ]
820836 except KeyError :
821- map_ = Map (map_ , interior_horizontal ,
822- (self .bottom_layer , self .top_layer ),
823- unroll = unroll )
837+ if isinstance (map_ , PermutedMap ):
838+ imap = self .map_ (map_ .map_ , unroll = unroll )
839+ map_ = PMap (imap , map_ .permutation )
840+ else :
841+ map_ = Map (map_ , interior_horizontal ,
842+ (self .bottom_layer , self .top_layer ),
843+ unroll = unroll )
824844 self .maps [key ] = map_
825845 return map_
826846
@@ -854,7 +874,11 @@ def wrapper_args(self):
854874 args .extend (self .arguments )
855875 # maps are refcounted
856876 for map_ in self .maps .values ():
857- args .append (map_ .values )
877+ # But we don't need to emit stuff for PMaps because they
878+ # are a Map (already seen + a permutation [encoded in the
879+ # indexing]).
880+ if not isinstance (map_ , PMap ):
881+ args .append (map_ .values )
858882 return tuple (args )
859883
860884 def kernel_call (self ):
0 commit comments