1- import heapq
21import random
3- import uuid
42from concurrent .futures import ThreadPoolExecutor , as_completed
53
64from brain .agent import agent
75from brain .agent_config import DEFAULT_CONFIG
86from brain .alpha_class import Alpha
7+ from brain .alpha_storage import Storage
98from brain .api import DEFAULT_CONFIG as API_DEFAULT_CONFIG
109from brain .api import BrainAPI
1110from brain .database import Database
@@ -20,23 +19,23 @@ def decay_hyperbolic(x, gamma=0.2, delta=0.1):
2019 return (gamma * x ) / (1 + delta * x )
2120
2221
23- def create_alpha_simulation (alphas_dict , alphas_categories ):
24- """Create a new alpha based on the given ID."""
22+ def get_score (alpha : Alpha ):
23+ if not alpha .visible :
24+ return float ("-inf" )
2525
26- def get_score (alpha_id ):
27- return (
28- alphas_dict [alpha_id ].fitness
29- + 1.5 * alphas_dict [alpha_id ].sharpe
30- - decay_hyperbolic (alphas_dict [alpha_id ].print_counter , gamma = 0.01 , delta = 0.02 )
31- )
26+ return (
27+ alpha .fitness
28+ + 1.5 * alpha .sharpe
29+ - decay_hyperbolic (alpha .print_counter , gamma = 0.01 , delta = 0.02 )
30+ )
3231
33- n_largest = 10
34- for cat in [ "passing" , "failing" ] :
35- alphas_categories [ cat ] = heapq . nlargest ( n_largest , alphas_categories [ cat ], key = get_score )
32+
33+ def create_alpha_simulation ( storage : Storage ) :
34+ """Create a new alpha based on the given ID."""
3635
3736 formatted_alphas = {
38- cat : "\n " .join (alphas_dict [ id ] .prompt_format () for id in alphas_categories [ cat ] )
39- for cat in alphas_categories . keys ()
37+ cat : "\n " .join (alpha .prompt_format () for alpha in storage . get_top_k ( cat , 10 ) )
38+ for cat in storage . categories
4039 }
4140
4241 if random .random () < 0.05 :
@@ -114,60 +113,80 @@ def monitor_alpha(response, alpha_config):
114113 }
115114
116115
117- def update_alphas_dict (alphas_dict , alphas_categories , stats , temp_id ):
116+ def update_alphas_dict (
117+ storage : Storage ,
118+ stats : dict ,
119+ temp_id : str ,
120+ ):
118121 """Update the alphas dictionary with the new stats."""
119- alphas_categories ["pending" ].remove (temp_id )
120- alphas_dict .pop (temp_id )
122+ storage .remove_pending_alpha (temp_id )
121123
122124 if stats ["alpha_id" ] is None :
123125 return
124126
125- alpha_id = stats ["alpha_id" ]
126- alphas_dict [alpha_id ] = Alpha .from_stats (stats )
127+ alpha = Alpha .from_stats (stats )
127128 try :
128- Database ().insert_alpha (alphas_dict [ alpha_id ] )
129+ Database ().insert_alpha (alpha )
129130 except Exception as e :
130131 print (f"Error during database insertion: { e } " )
131132 pass
132133
133- if alphas_dict [ alpha_id ] .short_count + alphas_dict [ alpha_id ] .long_count > 0 :
134+ if alpha .short_count + alpha .long_count > 0 :
134135 if (stats ["is_tests" ]["result" ] != "FAIL" ).all ():
135- alphas_categories [ "passing" ]. append ( alpha_id )
136+ storage . add_alpha ( alpha , "passing" )
136137 else :
137- alphas_categories [ "failing" ]. append ( alpha_id )
138+ storage . add_alpha ( alpha , "failing" )
138139
139- return alphas_dict [alpha_id ]
140+ return alpha
141+
142+
143+ def set_warm_start_alphas (storage : Storage ) -> None :
144+ """Initialize alphas_dict with warm start alphas from the database."""
145+ try :
146+ alphas = Database ().k_best_alphas (
147+ metric = "sharpe" ,
148+ top_k = 100 ,
149+ min_fitness = 1.0 ,
150+ max_self_corr = 0.6 ,
151+ )
152+
153+ alphas = random .sample (alphas , min (10 , len (alphas )))
154+ for alpha in alphas :
155+ alpha .hide_after = 30
156+ storage .add_alpha (alpha , "failing" )
157+
158+ except Exception as e :
159+ print (f"Error during database query: { e } " )
140160
141161
142162def main ():
143163 """Main function to run the agent."""
144- alphas_dict = {}
145- alphas_categories = {
146- "passing" : [],
147- "failing" : [],
148- "pending" : [],
149- }
164+ storage = Storage (score_func = get_score , max_size = 50 )
165+
166+ set_warm_start_alphas (storage )
150167
151168 with ThreadPoolExecutor (max_workers = MAX_WORKERS ) as pool :
152169 live_jobs = {}
153170
154171 for _ in range (MAX_WORKERS ):
155172 # Start a new alpha simulation
156- response , alpha_config = create_alpha_simulation (alphas_dict , alphas_categories )
157-
158- # Generate a unique ID for the alpha
159- temp_id = str (uuid .uuid4 ())
160- alphas_categories ["pending" ].append (temp_id )
161- alphas_dict [temp_id ] = Alpha .from_config (alpha_config )
162- live_jobs [pool .submit (monitor_alpha , response , alpha_config )] = (temp_id , alpha_config )
173+ response , alpha_config = create_alpha_simulation (storage )
174+
175+ # Create a temporary alpha configuration
176+ alpha = Alpha .from_config (alpha_config )
177+ storage .add_alpha (alpha , "pending" )
178+ live_jobs [pool .submit (monitor_alpha , response , alpha_config )] = (
179+ alpha .alpha_id ,
180+ alpha_config ,
181+ )
163182
164183 while live_jobs :
165184 for job in as_completed (live_jobs ):
166185 # Update alphas_dict with the results
167186 temp_id , alpha_config = live_jobs .pop (job ) # remove from “running” set
168187 stats = job .result ()
169188 print (f"Stats: { stats } " )
170- alpha = update_alphas_dict (alphas_dict , alphas_categories , stats , temp_id )
189+ alpha = update_alphas_dict (storage , stats , temp_id )
171190
172191 # Start a new alpha simulation
173192 if alpha is not None and alpha .alpha_id is not None and alpha .fitness < - 0.5 :
@@ -176,12 +195,11 @@ def main():
176195 alpha_config = {** alpha_config , "regular" : regular }
177196 response = BrainAPI .start_simulation (alpha_config )
178197 else :
179- response , alpha_config = create_alpha_simulation (alphas_dict , alphas_categories )
180- # Generate a unique ID for the alpha
181- temp_id = str (uuid .uuid4 ())
182- alphas_categories ["pending" ].append (temp_id )
183- alphas_dict [temp_id ] = Alpha .from_config (alpha_config )
198+ response , alpha_config = create_alpha_simulation (storage )
199+ # TODO: Turn this into a method + stop using alpha_config
200+ alpha = Alpha .from_config (alpha_config )
201+ storage .add_alpha (alpha , "pending" )
184202 live_jobs [pool .submit (monitor_alpha , response , alpha_config )] = (
185- temp_id ,
203+ alpha . alpha_id ,
186204 alpha_config ,
187205 )
0 commit comments