66
77# pyre-strict
88
9+ import dataclasses
910from abc import abstractmethod
1011from dataclasses import dataclass
1112from typing import Callable , List , Optional , override , Set , Type , TypeVar , Union
@@ -28,11 +29,20 @@ def allow_lifetime_and_storage_overlap(opt_level: int) -> bool:
2829 return opt_level >= 2
2930
3031
32+ # A dataclass that bundles feature flags for edge passes.
33+ # When adding a new flag, add a matching bool field to both this class and
34+ # CadencePassAttribute; the pass filter will pick it up automatically.
35+ @dataclass (frozen = True )
36+ class EdgePassesConfig :
37+ use_im2row_transform : bool = False
38+
39+
3140# A dataclass that stores the attributes of an ExportPass.
3241@dataclass (frozen = True )
3342class CadencePassAttribute :
3443 opt_level : Optional [int ] = None
3544 debug_pass : bool = False
45+ use_im2row_transform : bool = False
3646
3747
3848# A dictionary that maps an ExportPass to its attributes.
@@ -58,17 +68,38 @@ def get_all_available_cadence_passes() -> Set[Type[PassBase]]:
5868 return set (ALL_CADENCE_PASSES .keys ())
5969
6070
71+ def _check_feature_flags (
72+ pass_attribute : CadencePassAttribute ,
73+ config : EdgePassesConfig ,
74+ ) -> bool :
75+ """Check all feature flags: a pass is included only if every feature it
76+ requires is enabled in the config. Iterates over EdgePassesConfig fields
77+ so new flags are handled automatically."""
78+ for field in dataclasses .fields (EdgePassesConfig ):
79+ if getattr (pass_attribute , field .name , False ) and not getattr (
80+ config , field .name
81+ ):
82+ return False
83+ return True
84+
85+
6186# Create a new filter to filter out relevant passes from all passes.
6287def create_cadence_pass_filter (
63- opt_level : int , debug : bool = False
88+ opt_level : int ,
89+ debug : bool = False ,
90+ edge_passes_config : Optional [EdgePassesConfig ] = None ,
6491) -> Callable [[Type [PassBase ]], bool ]:
92+ if edge_passes_config is None :
93+ edge_passes_config = EdgePassesConfig ()
94+
6595 def _filter (p : Type [PassBase ]) -> bool :
6696 pass_attribute = get_cadence_pass_attribute (p )
6797 return (
6898 pass_attribute is not None
6999 and pass_attribute .opt_level is not None
70100 and pass_attribute .opt_level <= opt_level
71101 and (not pass_attribute .debug_pass or debug )
102+ and _check_feature_flags (pass_attribute , edge_passes_config )
72103 )
73104
74105 return _filter
0 commit comments