4343 CSECachingMapperMixin ,
4444 )
4545import immutables
46+ from pymbolic .mapper .equality import (
47+ EqualityMapper as EqualityMapperBase )
4648from pymbolic .mapper .evaluator import \
4749 CachedEvaluationMapper as EvaluationMapperBase
4850from pymbolic .mapper .substitutor import \
6062
6163from pymbolic .parser import Parser as ParserBase
6264from loopy .diagnostic import LoopyError
63- from loopy .diagnostic import (ExpressionToAffineConversionError ,
64- UnableToDetermineAccessRangeError )
65+ from loopy .diagnostic import (
66+ ExpressionToAffineConversionError ,
67+ UnableToDetermineAccessRangeError )
6568
6669
6770import islpy as isl
@@ -117,8 +120,11 @@ def map_literal(self, expr, *args, **kwargs):
117120 return expr
118121
119122 def map_array_literal (self , expr , * args , ** kwargs ):
120- return type (expr )(tuple (self .rec (ch , * args , ** kwargs )
121- for ch in expr .children ))
123+ children = [self .rec (ch , * args , ** kwargs ) for ch in expr .children ]
124+ if all (ch is orig for ch , orig in zip (children , expr .children )):
125+ return expr
126+
127+ return type (expr )(tuple (children ))
122128
123129 def map_group_hw_index (self , expr , * args , ** kwargs ):
124130 return expr
@@ -501,6 +507,60 @@ def map_substitution(self, name, rule, arguments):
501507
502508 return self .rec (expr )
503509
510+
511+ class EqualityMapper (EqualityMapperBase ):
512+ def map_loopy_function_identifier (self , expr , other ) -> bool :
513+ return True
514+
515+ def map_reduction (self , expr , other ) -> bool :
516+ return (
517+ expr .operation == other .operation
518+ and expr .allow_simultaneous == other .allow_simultaneous
519+ and self .rec (expr .expr , other .expr )
520+ and all (iname == other_iname
521+ for iname , other_iname in zip (expr .inames , other .inames )))
522+
523+ def map_group_hw_index (self , expr , other ) -> bool :
524+ return expr .axis == other .axis
525+
526+ map_local_hw_index = map_group_hw_index
527+
528+ def map_linear_subscript (self , expr , other ) -> bool :
529+ return (
530+ self .rec (expr .index , other .index )
531+ and self .rec (expr .aggregate , other .aggregate ))
532+
533+ def map_rule_argument (self , expr , other ) -> bool :
534+ return expr .index == other .index
535+
536+ def map_resolved_function (self , expr , other ) -> bool :
537+ return self .rec (expr .function , other .function )
538+
539+ def map_sub_array_ref (self , expr , other ) -> bool :
540+ return (
541+ len (expr .swept_inames ) == len (other .swept_inames )
542+ and self .rec (expr .subscript , other .subscript )
543+ and all (self .rec (iname , other_iname )
544+ for iname , other_iname in zip (
545+ expr .swept_inames ,
546+ other .swept_inames ))
547+ )
548+
549+ def map_tagged_variable (self , expr , other ) -> bool :
550+ return (
551+ expr .name == other .name
552+ and all (tag == other_tag
553+ for tag , other_tag in zip (expr .tags , other .tags ))
554+ )
555+
556+ def map_type_cast (self , expr , other ) -> bool :
557+ return (
558+ expr .type == other .type
559+ and self .rec (expr .child , other .child ))
560+
561+ def map_fortran_division (self , expr , other ) -> bool :
562+ return self .map_quotient (expr , other )
563+
504564# }}}
505565
506566
@@ -514,15 +574,18 @@ def stringifier(self):
514574 def make_stringifier (self , originating_stringifier = None ):
515575 return StringifyMapper ()
516576
577+ def make_equality_mapper (self ):
578+ return EqualityMapper ()
579+
517580
518581class Literal (LoopyExpressionBase ):
519582 """A literal to be used during code generation.
520583
521584 .. note::
522585
523586 Only used in the output of
524- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
525- similar mappers). Not for use in Loopy source representation.
587+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
588+ (and similar mappers). Not for use in :mod:`loopy` source representation.
526589 """
527590
528591 def __init__ (self , s ):
@@ -542,8 +605,8 @@ class ArrayLiteral(LoopyExpressionBase):
542605 .. note::
543606
544607 Only used in the output of
545- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
546- similar mappers). Not for use in Loopy source representation.
608+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
609+ (and similar mappers). Not for use in :mod:`loopy` source representation.
547610 """
548611
549612 def __init__ (self , children ):
@@ -572,8 +635,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
572635 .. note::
573636
574637 Only used in the output of
575- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
576- similar mappers). Not for use in Loopy source representation.
638+ :class :`loopy.target.c.codegen. expression.ExpressionToCExpressionMapper`
639+ (and similar mappers). Not for use in :mod:`loopy` source representation.
577640 """
578641 mapper_method = "map_group_hw_index"
579642
@@ -583,8 +646,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
583646 .. note::
584647
585648 Only used in the output of
586- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
587- similar mappers). Not for use in Loopy source representation.
649+ :class :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
650+ similar mappers). Not for use in :mod:`loopy` source representation.
588651 """
589652 mapper_method = "map_local_hw_index"
590653
@@ -791,12 +854,6 @@ def __getinitargs__(self):
791854 def get_hash (self ):
792855 return hash ((self .__class__ , self .operation , self .inames , self .expr ))
793856
794- def is_equal (self , other ):
795- return (other .__class__ == self .__class__
796- and other .operation == self .operation
797- and other .inames == self .inames
798- and other .expr == self .expr )
799-
800857 @property
801858 def is_tuple_typed (self ):
802859 return self .operation .arg_count > 1
@@ -994,14 +1051,6 @@ def __getinitargs__(self):
9941051 def get_hash (self ):
9951052 return hash ((self .__class__ , self .swept_inames , self .subscript ))
9961053
997- def is_equal (self , other ):
998- """
999- Returns *True* iff the sub-array refs have identical expressions.
1000- """
1001- return (other .__class__ == self .__class__
1002- and other .subscript == self .subscript
1003- and other .swept_inames == self .swept_inames )
1004-
10051054 def make_stringifier (self , originating_stringifier = None ):
10061055 return StringifyMapper ()
10071056
0 commit comments