@@ -62,81 +62,65 @@ def log_total_time(self):
6262 logger .info (f"============================\n " )
6363
6464
65+ def _apply_edp_columns (df : pd .DataFrame , metrics : Metrics ) -> pd .DataFrame :
66+ if not (metrics & Metrics .ENERGY_DELAY_PRODUCT ):
67+ return df
68+ if not (metrics & Metrics .ENERGY ):
69+ del df ["Total<SEP>energy" ]
70+ if not (metrics & Metrics .LATENCY ):
71+ del df ["Total<SEP>latency" ]
72+ return df
73+
74+
6575class OptimalityThresholder :
6676 def __init__ (
6777 self ,
6878 prev_solutions : Mappings ,
6979 _pmapping_row_filter_function : Callable [[pd .DataFrame ], np .ndarray ],
70- aggregator : str ,
7180 print_progress : bool ,
81+ metrics : Metrics ,
7282 ):
73- compare_to = prev_solutions .data
83+ self .metrics = metrics
84+ compare_to = _apply_edp_columns (prev_solutions .data .copy (), metrics )
7485 compare_cols = [c for c in compare_to .columns if col_used_in_pareto (c )]
7586 self ._pmapping_row_filter_function = _pmapping_row_filter_function
76- self .aggregator = aggregator
77-
78- if self .aggregator in ("prod" , "sum" ):
79- objective_cols = [c for c in compare_cols if is_objective_col (c )]
80- self ._agg_cols = objective_cols
81- if objective_cols :
82- values = np .column_stack ([compare_to [c ].values for c in objective_cols ])
83- if self .aggregator == "prod" :
84- agg = np .prod (values , axis = 1 )
85- else :
86- agg = np .sum (values , axis = 1 )
87- self ._agg_threshold = agg .min ()
88- else :
89- self ._agg_threshold = float ("inf" )
90- if print_progress :
91- label = "product" if self .aggregator == "prod" else "sum"
92- print (
93- f"Filtering out pmappings with { label } > "
94- f"{ self ._agg_threshold :.2e} "
95- )
96- else : # "any"
97- compare_to = compare_to .sort_values (by = compare_cols , ascending = False )
9887
99- if len (compare_to ) > 10 :
100- chosen_indices = np .round (np .linspace (0 , len (compare_to ) - 1 , 10 ))
101- else :
102- chosen_indices = np .round (np .arange (len (compare_to )))
88+ compare_to = compare_to .sort_values (by = compare_cols , ascending = False )
89+
90+ if len (compare_to ) > 10 :
91+ chosen_indices = np .round (np .linspace (0 , len (compare_to ) - 1 , 10 ))
92+ else :
93+ chosen_indices = np .round (np .arange (len (compare_to )))
94+
95+ self .compare_to : list [dict [str , float ]] = []
96+ if print_progress :
97+ print (f"Filtering out pmappings worse than the following:" )
10398
104- self .compare_to : list [dict [str , float ]] = []
99+ for i in chosen_indices .astype (int ):
100+ self .compare_to .append ({c : compare_to [c ].iloc [i ] for c in compare_cols })
105101 if print_progress :
106- print (f"Filtering out pmappings worse than the following:" )
107-
108- for i in chosen_indices .astype (int ):
109- self .compare_to .append ({c : compare_to [c ].iloc [i ] for c in compare_cols })
110- if print_progress :
111- print (
112- "\t "
113- + " " .join (
114- f"{ k } ={ float (v ):.2e} "
115- for k , v in self .compare_to [- 1 ].items ()
116- )
102+ print (
103+ "\t "
104+ + " " .join (
105+ f"{ k } ={ float (v ):.2e} " for k , v in self .compare_to [- 1 ].items ()
117106 )
107+ )
118108
119109 def __call__ (self , mapping : pd .DataFrame ) -> bool :
120110 nondominated_by_all = np .ones (len (mapping ), dtype = bool )
121111
122- if self .aggregator in ("prod" , "sum" ):
123- cols_present = [c for c in self ._agg_cols if c in mapping .columns ]
124- if cols_present :
125- values = np .column_stack ([mapping [c ].values for c in cols_present ])
126- if self .aggregator == "prod" :
127- agg = np .prod (values , axis = 1 )
112+ edp_mapping = _apply_edp_columns (
113+ mapping .copy (), self .metrics , return_only_objectives = True
114+ )
115+
116+ for c in self .compare_to :
117+ nondominated = np .zeros (len (edp_mapping ), dtype = bool )
118+ for k , v in c .items ():
119+ if k not in edp_mapping .columns :
120+ nondominated |= True
128121 else :
129- agg = np .sum (values , axis = 1 )
130- nondominated_by_all = agg <= self ._agg_threshold
131- else : # "any"
132- for c in self .compare_to :
133- nondominated = np .zeros (len (mapping ), dtype = bool )
134- for k , v in c .items ():
135- if k not in mapping .columns :
136- nondominated |= True
137- else :
138- nondominated |= mapping [k ] <= v
139- nondominated_by_all &= nondominated
122+ nondominated |= edp_mapping [k ] <= v
123+ nondominated_by_all &= nondominated
140124
141125 if self ._pmapping_row_filter_function is not None :
142126 nondominated_by_all &= self ._pmapping_row_filter_function (mapping )
@@ -235,8 +219,8 @@ def join_strategy_2(
235219 filter_func = OptimalityThresholder (
236220 joined ,
237221 _pmapping_row_filter_function ,
238- spec .mapper ._metric_aggregator ,
239222 print_progress ,
223+ metrics ,
240224 )
241225 except Exception as e :
242226 if i == len (thresholds ) - 1 :
@@ -356,6 +340,8 @@ def clean_compress_and_join_pmappings(
356340
357341 joined = decompress_pmappings (joined , decompress_data )
358342
343+ _apply_edp_columns (joined .data , metrics )
344+
359345 for einsum_name in einsum2pmappings :
360346 col = f"{ einsum_name } <SEP>{ MAPPING_COLUMN } "
361347 joined .data [col ] = joined .data [col ].apply (
0 commit comments