Skip to content

Commit 7a9b8ab

Browse files
committed
Wroking version of basic fine-tuning
1 parent 48beb83 commit 7a9b8ab

3 files changed

Lines changed: 38 additions & 20 deletions

File tree

brain/api.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,28 @@
2323

2424

2525
class BrainAPI:
26-
"""Class for BrainAPI session."""
26+
"""Class for communicating with BrainAPI session."""
27+
28+
_instance = None
2729

2830
def __init__(self):
2931
self.s = requests.Session()
3032
self._authenticate()
3133

34+
def __new__(cls):
35+
# Singleton pattern to ensure only one instance of BrainAPI exists.
36+
if cls._instance is None:
37+
cls._instance = super(BrainAPI, cls).__new__(cls)
38+
return cls._instance
39+
3240
def get_simulation_result_json(self, alpha_id):
3341
"""Get result of simulation as JSON dictionary."""
3442
return self._request("GET", "/alphas/" + alpha_id).json()
3543

3644
def start_simulation(self, simulate_data):
3745
"""Start simulation of provided alpha."""
38-
return self._request("POST", "/simulations", json=simulate_data)
46+
self.check_session_timeout()
47+
return self.s.post(f"{BRAIN_API_URL}/simulations", json=simulate_data)
3948

4049
def check_session_timeout(self):
4150
"""Check session time out and refresh session if necessary."""
@@ -334,6 +343,7 @@ def get_specified_alpha_stats(
334343
.drop_duplicates(subset=["test"], keep="last")
335344
.reset_index(drop=True)
336345
)
346+
337347
if check_prod_corr and not check_submission:
338348
prod_corr_test = self.check_prod_corr_test(alpha_id)
339349
is_tests = (
@@ -527,7 +537,7 @@ def get_datafields(
527537
datafields_df = pd.DataFrame(datafields_list_flat)
528538
return datafields_df
529539

530-
def _request(self, method: str, endpoint: str, max_retries: int = 20, **kwargs):
540+
def _request(self, method: str, endpoint: str, max_retries: int = 120, **kwargs):
531541
"""Generic request method with retry logic."""
532542
while max_retries > 0:
533543
response = self.s.request(method, BRAIN_API_URL + endpoint, **kwargs)

brain/fine_tune.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from itertools import product
44

55
from brain.alpha_helpers import generate_alpha
6+
from brain.api import BrainAPI
67

78

89
def generate_alpha_grid(regular: str):
@@ -26,5 +27,24 @@ def generate_alpha_grid(regular: str):
2627
return alphas
2728

2829

29-
def fine_tune_alpha():
30-
pass
30+
def get_fitness(result):
31+
"""Get the fitness of the alpha from result."""
32+
stats = result["is_stats"]
33+
if "fitness" in stats:
34+
return stats["fitness"][0]
35+
36+
return -1
37+
38+
39+
def fine_tune_alpha(regular: str):
40+
alphas = generate_alpha_grid(regular)
41+
results = BrainAPI().simulate_alpha_list(alphas)
42+
sorted_results = sorted(results, key=get_fitness, reverse=True)
43+
44+
print("Best alpha parameters:")
45+
for i, result in enumerate(sorted_results):
46+
print(f"Rank {i + 1}:")
47+
print(f"Alpha: {result['alpha']}")
48+
print(f"Fitness: {get_fitness(result)}")
49+
print(f"Parameters: {result['simulate_data']}")
50+
print()

brain/main.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,12 @@
22

33
from brain.alpha_helpers import generate_alpha
44
from brain.api import BrainAPI
5+
from brain.fine_tune import fine_tune_alpha
56

67

78
def main():
8-
api = BrainAPI()
9-
10-
k = [
11-
"vwap * 2",
12-
"open * close",
13-
"high * low",
14-
"vwap * 3",
15-
"open * close",
16-
"high * low",
17-
]
18-
alpha_list = [generate_alpha(x) for x in k]
19-
20-
results = api.simulate_alpha_list(alpha_list)
21-
print("Simulation results:")
22-
print(results)
9+
alpha = "-ts_mean(snt_buzz_bfl, 2) * ts_zscore(snt_buzz, 20)"
10+
fine_tune_alpha(alpha)
2311

2412

2513
if __name__ == "__main__":

0 commit comments

Comments
 (0)