114114# {{{ mappers with support for loopy-specific primitives
115115
116116class IdentityMapperMixin :
117+ def map_with_tag (self , expr , * args , ** kwargs ):
118+ new_expr = self .rec (expr .expr , * args , ** kwargs )
119+ return WithTag (expr .tags , new_expr )
120+
117121 def map_literal (self , expr , * args , ** kwargs ):
118122 return expr
119123
@@ -207,6 +211,12 @@ def map_common_subexpression_uncached(self, expr):
207211
208212
209213class WalkMapperMixin :
214+ def map_with_tag (self , expr , * args , ** kwargs ):
215+ if not self .visit (expr , * args , ** kwargs ):
216+ return
217+
218+ self .rec (expr .expr , * args , ** kwargs )
219+
210220 def map_literal (self , expr , * args , ** kwargs ):
211221 self .visit (expr , * args , ** kwargs )
212222
@@ -273,6 +283,9 @@ class CallbackMapper(IdentityMapperMixin, CallbackMapperBase):
273283
274284
275285class CombineMapper (CombineMapperBase ):
286+ def map_with_tag (self , expr , * args , ** kwargs ):
287+ return self .rec (expr .expr , * args , ** kwargs )
288+
276289 def map_reduction (self , expr , * args , ** kwargs ):
277290 return self .rec (expr .expr , * args , ** kwargs )
278291
@@ -298,6 +311,10 @@ class ConstantFoldingMapper(ConstantFoldingMapperBase,
298311
299312
300313class StringifyMapper (StringifyMapperBase ):
314+ def map_with_tag (self , expr , * args ):
315+ from pymbolic .mapper .stringifier import PREC_NONE
316+ return f"WithTag({ expr .tags } , { self .rec (expr .expr , PREC_NONE )} "
317+
301318 def map_literal (self , expr , * args ):
302319 return expr .s
303320
@@ -440,6 +457,10 @@ def map_tagged_variable(self, expr, *args, **kwargs):
440457 def map_loopy_function_identifier (self , expr , * args , ** kwargs ):
441458 return set ()
442459
460+ def map_with_tag (self , expr , * args , ** kwargs ):
461+ deps = self .rec (expr .expr , * args , ** kwargs )
462+ return deps
463+
443464 def map_sub_array_ref (self , expr , * args , ** kwargs ):
444465 deps = self .rec (expr .subscript , * args , ** kwargs )
445466 return deps - set (expr .swept_inames )
@@ -712,6 +733,31 @@ def copy(self, *, name=None, tags=None):
712733 mapper_method = intern ("map_tagged_variable" )
713734
714735
736+ class WithTag (LoopyExpressionBase ):
737+ """
738+ Represents a frozenset of tags attached to an :attr:`expr`.
739+ """
740+
741+ init_arg_names = ("tags" , "expr" )
742+
743+ def __init__ (self , tags , expr ):
744+ self .tags = tags
745+ self .expr = expr
746+
747+ def __getinitargs__ (self ):
748+ return (self .tags , self .expr )
749+
750+ def get_hash (self ):
751+ return hash ((self .__class__ , self .tags , self .expr ))
752+
753+ def is_equal (self , other ):
754+ return (other .__class__ == self .__class__
755+ and other .tags == self .tags
756+ and other .expr == self .expr )
757+
758+ mapper_method = intern ("map_with_tag" )
759+
760+
715761class Reduction (LoopyExpressionBase ):
716762 """
717763 Represents a reduction operation on :attr:`expr` across :attr:`inames`.
0 commit comments