4040 CallbackMapper as CallbackMapperBase ,
4141 CSECachingMapperMixin ,
4242 )
43- from pymbolic .mapper .evaluator import \
44- EvaluationMapper as EvaluationMapperBase
45- from pymbolic .mapper .substitutor import \
46- SubstitutionMapper as SubstitutionMapperBase
47- from pymbolic .mapper .stringifier import \
48- StringifyMapper as StringifyMapperBase
49- from pymbolic .mapper .dependency import \
50- DependencyMapper as DependencyMapperBase
51- from pymbolic .mapper .coefficient import \
52- CoefficientCollector as CoefficientCollectorBase
53- from pymbolic .mapper .unifier import UnidirectionalUnifier \
54- as UnidirectionalUnifierBase
55- from pymbolic .mapper .constant_folder import \
56- ConstantFoldingMapper as ConstantFoldingMapperBase
43+ from pymbolic .mapper .equality import (
44+ EqualityMapper as EqualityMapperBase )
45+ from pymbolic .mapper .evaluator import (
46+ EvaluationMapper as EvaluationMapperBase )
47+ from pymbolic .mapper .substitutor import (
48+ SubstitutionMapper as SubstitutionMapperBase )
49+ from pymbolic .mapper .stringifier import (
50+ StringifyMapper as StringifyMapperBase )
51+ from pymbolic .mapper .dependency import (
52+ DependencyMapper as DependencyMapperBase )
53+ from pymbolic .mapper .coefficient import (
54+ CoefficientCollector as CoefficientCollectorBase )
55+ from pymbolic .mapper .unifier import (
56+ UnidirectionalUnifier as UnidirectionalUnifierBase )
57+ from pymbolic .mapper .constant_folder import (
58+ ConstantFoldingMapper as ConstantFoldingMapperBase )
5759
5860from pymbolic .parser import Parser as ParserBase
5961from loopy .diagnostic import LoopyError
60- from loopy .diagnostic import (ExpressionToAffineConversionError ,
61- UnableToDetermineAccessRangeError )
62+ from loopy .diagnostic import (
63+ ExpressionToAffineConversionError ,
64+ UnableToDetermineAccessRangeError )
6265
6366
6467import islpy as isl
@@ -114,8 +117,11 @@ def map_literal(self, expr, *args, **kwargs):
114117 return expr
115118
116119 def map_array_literal (self , expr , * args , ** kwargs ):
117- return type (expr )(tuple (self .rec (ch , * args , ** kwargs )
118- for ch in expr .children ))
120+ children = [self .rec (ch , * args , ** kwargs ) for ch in expr .children ]
121+ if all (ch is orig for ch , orig in zip (children , expr .children )):
122+ return expr
123+
124+ return type (expr )(tuple (children ))
119125
120126 def map_group_hw_index (self , expr , * args , ** kwargs ):
121127 return expr
@@ -484,6 +490,55 @@ def map_substitution(self, name, rule, arguments):
484490
485491 return self .rec (expr )
486492
493+
494+ class EqualityMapper (EqualityMapperBase ):
495+ def map_loopy_function_identifier (self , expr , other ) -> bool :
496+ return True
497+
498+ def map_reduction (self , expr , other ) -> bool :
499+ return (
500+ expr .operation == other .operation
501+ and expr .allow_simultaneous == other .allow_simultaneous
502+ and self .rec (expr .expr , other .expr )
503+ and all (iname == other_iname
504+ for iname , other_iname in zip (expr .inames , other .inames )))
505+
506+ def map_group_hw_index (self , expr , other ) -> bool :
507+ return expr .axis == other .axis
508+
509+ map_local_hw_index = map_group_hw_index
510+
511+ def map_rule_argument (self , expr , other ) -> bool :
512+ return expr .index == other .index
513+
514+ def map_resolved_function (self , expr , other ) -> bool :
515+ return self .rec (expr .function , other .function )
516+
517+ def map_sub_array_ref (self , expr , other ) -> bool :
518+ return (
519+ len (expr .swept_inames ) == len (other .swept_inames )
520+ and self .rec (expr .subscript , other .subscript )
521+ and all (self .rec (iname , other_iname )
522+ for iname , other_iname in zip (
523+ expr .swept_inames ,
524+ other .swept_inames ))
525+ )
526+
527+ def map_tagged_variable (self , expr , other ) -> bool :
528+ return (
529+ expr .name == other .name
530+ and all (tag == other_tag
531+ for tag , other_tag in zip (expr .tags , other .tags ))
532+ )
533+
534+ def map_type_cast (self , expr , other ) -> bool :
535+ return (
536+ expr .type == other .type
537+ and self .rec (expr .child , other .child ))
538+
539+ def map_fortran_division (self , expr , other ) -> bool :
540+ return self .map_quotient (expr , other )
541+
487542# }}}
488543
489544
@@ -497,15 +552,18 @@ def stringifier(self):
497552 def make_stringifier (self , originating_stringifier = None ):
498553 return StringifyMapper ()
499554
555+ def make_equality_mapper (self ):
556+ return EqualityMapper ()
557+
500558
501559class Literal (LoopyExpressionBase ):
502560 """A literal to be used during code generation.
503561
504562 .. note::
505563
506564 Only used in the output of
507- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
508- similar mappers). Not for use in Loopy source representation.
565+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
566+ (and similar mappers). Not for use in :mod:`loopy` source representation.
509567 """
510568
511569 def __init__ (self , s ):
@@ -525,8 +583,8 @@ class ArrayLiteral(LoopyExpressionBase):
525583 .. note::
526584
527585 Only used in the output of
528- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
529- similar mappers). Not for use in Loopy source representation.
586+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
587+ (and similar mappers). Not for use in :mod:`loopy` source representation.
530588 """
531589
532590 def __init__ (self , children ):
@@ -555,8 +613,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
555613 .. note::
556614
557615 Only used in the output of
558- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
559- similar mappers). Not for use in Loopy source representation.
616+ :class :`loopy.target.c.codegen. expression.ExpressionToCExpressionMapper`
617+ (and similar mappers). Not for use in :mod:`loopy` source representation.
560618 """
561619 mapper_method = "map_group_hw_index"
562620
@@ -566,8 +624,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
566624 .. note::
567625
568626 Only used in the output of
569- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
570- similar mappers). Not for use in Loopy source representation.
627+ :class :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
628+ similar mappers). Not for use in :mod:`loopy` source representation.
571629 """
572630 mapper_method = "map_local_hw_index"
573631
@@ -774,12 +832,6 @@ def __getinitargs__(self):
774832 def get_hash (self ):
775833 return hash ((self .__class__ , self .operation , self .inames , self .expr ))
776834
777- def is_equal (self , other ):
778- return (other .__class__ == self .__class__
779- and other .operation == self .operation
780- and other .inames == self .inames
781- and other .expr == self .expr )
782-
783835 @property
784836 def is_tuple_typed (self ):
785837 return self .operation .arg_count > 1
@@ -977,14 +1029,6 @@ def __getinitargs__(self):
9771029 def get_hash (self ):
9781030 return hash ((self .__class__ , self .swept_inames , self .subscript ))
9791031
980- def is_equal (self , other ):
981- """
982- Returns *True* iff the sub-array refs have identical expressions.
983- """
984- return (other .__class__ == self .__class__
985- and other .subscript == self .subscript
986- and other .swept_inames == self .swept_inames )
987-
9881032 def make_stringifier (self , originating_stringifier = None ):
9891033 return StringifyMapper ()
9901034
0 commit comments