22# Email: simon.blanke@yahoo.com
33# License: MIT License
44
5- import copy
6- import inspect
75import numpy as np
86import pandas as pd
97
108
119from .objective_function import ObjectiveFunction
1210from .hyper_gradient_conv import HyperGradientConv
11+ from .base_optimizer import BaseOptimizer
1312
1413
15- class TrafoClass :
16- def __init__ (self , * args , ** kwargs ):
17- pass
18-
19- def _convert_args2gfo (self , memory_warm_start ):
20- memory_warm_start = self .hg_conv .conv_memory_warm_start (memory_warm_start )
21-
22- return memory_warm_start
23-
24- def _positions2results (self , positions ):
25- results_dict = {}
26-
27- for para_name in self .conv .para_names :
28- values_list = self .s_space [para_name ]
29- pos_ = positions [para_name ].values
30- values_ = [values_list [idx ] for idx in pos_ ]
31- results_dict [para_name ] = values_
32-
33- results = pd .DataFrame .from_dict (results_dict )
34-
35- diff_list = np .setdiff1d (positions .columns , results .columns )
36- results [diff_list ] = positions [diff_list ]
37-
38- return results
39-
40- def _convert_results2hyper (self ):
41- self .eval_times = np .array (self ._optimizer .eval_times ).sum ()
42- self .iter_times = np .array (self ._optimizer .iter_times ).sum ()
43-
44- if self ._optimizer .best_para is not None :
45- value = self .hg_conv .para2value (self ._optimizer .best_para )
46- position = self .hg_conv .position2value (value )
47- best_para = self .hg_conv .value2para (position )
48-
49- self .best_para = best_para
50- else :
51- self .best_para = None
52-
53- self .best_score = self ._optimizer .best_score
54- self .positions = self ._optimizer .search_data
55-
56- self .search_data = self ._positions2results (self .positions )
57-
58- results_dd = self ._optimizer .search_data .drop_duplicates (
59- subset = self .s_space .dim_keys , keep = "first"
60- )
61- self .memory_values_df = results_dd [
62- self .s_space .dim_keys + ["score" ]
63- ].reset_index (drop = True )
64-
65-
66- class _BaseOptimizer_ (TrafoClass ):
14+ class HyperOptimizer (BaseOptimizer ):
6715 def __init__ (self , ** opt_params ):
6816 super ().__init__ ()
6917 self .opt_params = opt_params
@@ -104,6 +52,31 @@ def setup_search(
10452 else :
10553 self .verbosity = []
10654
55+ def convert_results2hyper (self ):
56+ self .eval_times = np .array (self .opt_algo .eval_times ).sum ()
57+ self .iter_times = np .array (self .opt_algo .iter_times ).sum ()
58+
59+ if self .opt_algo .best_para is not None :
60+ value = self .hg_conv .para2value (self .opt_algo .best_para )
61+ position = self .hg_conv .position2value (value )
62+ best_para = self .hg_conv .value2para (position )
63+
64+ self .best_para = best_para
65+ else :
66+ self .best_para = None
67+
68+ self .best_score = self .opt_algo .best_score
69+ self .positions = self .opt_algo .search_data
70+
71+ self .search_data = self .hg_conv .positions2results (self .positions )
72+
73+ results_dd = self .opt_algo .search_data .drop_duplicates (
74+ subset = self .s_space .dim_keys , keep = "first"
75+ )
76+ self .memory_values_df = results_dd [
77+ self .s_space .dim_keys + ["score" ]
78+ ].reset_index (drop = True )
79+
10780 def _setup_process (self , nth_process ):
10881 self .nth_process = nth_process
10982
@@ -118,33 +91,33 @@ def _setup_process(self, nth_process):
11891 self .opt_params ["warm_start_smbo" ]
11992 )
12093
121- self ._optimizer = self ._OptimizerClass (
94+ self .opt_algo = self ._OptimizerClass (
12295 search_space = search_space_positions ,
12396 initialize = initialize ,
12497 random_state = self .random_state ,
12598 nth_process = nth_process ,
12699 ** self .opt_params
127100 )
128101
129- self .conv = self ._optimizer .conv
102+ self .conv = self .opt_algo .conv
130103
131104 def search (self , nth_process ):
132105 self ._setup_process (nth_process )
133106
134107 gfo_wrapper_model = ObjectiveFunction (
135108 objective_function = self .objective_function ,
136- optimizer = self ._optimizer ,
109+ optimizer = self .opt_algo ,
137110 callbacks = self .callbacks ,
138111 catch = self .catch ,
139112 nth_process = self .nth_process ,
140113 )
141114 gfo_wrapper_model .pass_through = self .pass_through
142115
143- memory_warm_start = self ._convert_args2gfo (self .memory_warm_start )
116+ memory_warm_start = self .hg_conv . conv_memory_warm_start (self .memory_warm_start )
144117
145118 gfo_objective_function = gfo_wrapper_model (self .s_space ())
146119
147- self ._optimizer .search (
120+ self .opt_algo .search (
148121 objective_function = gfo_objective_function ,
149122 n_iter = self .n_iter ,
150123 max_time = self .max_time ,
@@ -155,5 +128,16 @@ def search(self, nth_process):
155128 verbosity = self .verbosity ,
156129 )
157130
158- self ._convert_results2hyper ()
159- self .p_bar = self ._optimizer .p_bar
131+ self .convert_results2hyper ()
132+
133+ self ._add_result_attributes (
134+ self .best_para ,
135+ self .best_score ,
136+ self .opt_algo .p_bar ._best_since_iter ,
137+ self .eval_times ,
138+ self .iter_times ,
139+ self .positions ,
140+ self .search_data ,
141+ self .memory_values_df ,
142+ self .opt_algo .random_seed ,
143+ )
0 commit comments