Skip to content

Commit a524c3d

Browse files
committed
Working on fine-tuning script
1 parent 5d57a58 commit a524c3d

7 files changed

Lines changed: 191 additions & 141 deletions

File tree

brain/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class CustomState(AgentState):
5959
tools = [
6060
submit_alpha,
6161
describe_operators,
62-
# search_datafields,
62+
search_datafields,
6363
get_random_datafields,
6464
get_random_idea,
6565
]

brain/agent_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
DEFAULT_CONFIG = {
44
"region": "USA",
55
"universe": "TOP3000",
6-
"neutralization": "SECTOR",
7-
"decay": 10,
6+
"neutralization": "INDUSTRY",
7+
"decay": 5,
88
"delay": 1,
99
}
1010

brain/alpha_class.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,12 @@ def from_stats(cls, stats: dict) -> "Alpha":
180180
"neutralization",
181181
"pasteurization",
182182
]
183-
is_stats = stats["is_stats"].iloc[0]
183+
184+
if "train" in stats:
185+
is_stats = stats["train"]
186+
else:
187+
is_stats = stats["is_stats"].iloc[0]
188+
184189
is_tests_df = stats["is_tests"]
185190
self_corr = is_tests_df[is_tests_df["name"] == "SELF_CORRELATION"].iloc[0]["value"]
186191

brain/fine_tune.py

Lines changed: 76 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,89 @@
11
"""Methods for fine-tuning parameters of alpha."""
22

3-
from itertools import product
3+
import random
44

5+
from brain.agent import agent
6+
from brain.agent_config import DEFAULT_CONFIG
57
from brain.alpha_class import Alpha
6-
from brain.api import BrainAPI
7-
8-
9-
def generate_alpha_grid(regular: str):
10-
"""Generate a grid of alpha parameters for fine-tuning."""
11-
# Define the grid of parameters to explore
12-
param_options = {
13-
# "universe": ["TOP3000", "TOP1000", "TOP500", "TOP200", "TOPSP500"],
14-
"universe": ["TOP3000", "TOP1000", "TOP500"],
15-
"neutralization": ["INDUSTRY", "SECTOR", "MARKET", "NONE", "SUBINDUSTRY"],
16-
"decay": [4, 8, 16, 32],
17-
"truncation": [0.01, 0.05, 0.1],
18-
# "pasteurization": ["ON", "OFF"],
19-
"pasteurization": ["ON"],
8+
from brain.alpha_storage import Storage
9+
from brain.genetic_algorithm import genetic_algorithm
10+
from brain.score import get_score
11+
from brain.tools.ideas import get_random_idea
12+
from brain.tools.simulation import StopException
13+
14+
MAIN_ALPHA = Alpha(
15+
regular="ts_corr(fnd6_newqv1300_lltq, fnd6_newqv1300_aociotherq, 40) * zscore(ts_mean(pcr_vol_120, 40))",
16+
)
17+
18+
19+
def create_alpha_simulation(storage: Storage):
20+
"""Create a new alpha based on the given ID."""
21+
22+
formatted_alphas = {
23+
cat: "\n".join(alpha.prompt_format() for alpha in storage.get_top_k(cat, 10))
24+
for cat in storage.categories
2025
}
2126

22-
# Generate all combinations of parameters
23-
combinations = list(product(*param_options.values()))
24-
params = [dict(zip(param_options.keys(), values)) for values in combinations]
25-
alphas = [Alpha.create_alpha(regular=regular, **p) for p in params]
27+
if random.random() < 0.05:
28+
prompt = "Create a completely new alpha by random data fields."
29+
else:
30+
prompt = f"""
31+
Your task is to fine-tune the parameters of the following alpha:
32+
{MAIN_ALPHA.prompt_format()}
33+
Create a new alphas by adding or removing data fields, changing parameters, or modifying the logic.
34+
You can add/remove data fields, change parameters, add operators like neutralization
35+
or modify the logic.
36+
HOWEVER, THE CORE LOGIC OF THE ALPHA SHOULD REMAIN THE SAME.
37+
38+
PASSING
39+
-------
40+
{formatted_alphas['passing']}
41+
42+
FAILING
43+
-------
44+
{formatted_alphas['failing']}
45+
46+
PENDING
47+
-------
48+
{formatted_alphas['pending']}
49+
50+
{get_random_idea() if random.random() < 0.3 else ''}
51+
"""
2652

27-
return alphas
53+
print(f"Prompt:\n{prompt}")
2854

55+
alphas_store = []
56+
while not alphas_store:
57+
try:
58+
agent.invoke(
59+
{
60+
"messages": [
61+
{
62+
"role": "user",
63+
"content": prompt,
64+
}
65+
]
66+
},
67+
config={
68+
"recursion_limit": 100,
69+
"configurable": {
70+
**DEFAULT_CONFIG,
71+
"alphas": alphas_store,
72+
},
73+
},
74+
)
75+
except StopException:
76+
continue
2977

30-
def get_fitness(result):
31-
"""Get the fitness of the alpha from result."""
32-
stats = result["is_stats"]
33-
if "fitness" in stats:
34-
return stats["fitness"][0]
78+
print(f"Alphas store: {alphas_store}")
79+
return alphas_store[-1]
3580

36-
return -1
3781

82+
def main():
83+
"""Main function to run the agent."""
84+
storage = Storage(score_func=get_score, max_size=50)
85+
genetic_algorithm(storage, create_alpha_simulation)
3886

39-
def fine_tune_alpha(regular: str):
40-
alphas = generate_alpha_grid(regular)
41-
results = BrainAPI.simulate_alpha_list(alphas)
42-
sorted_results = sorted(results, key=get_fitness, reverse=True)
4387

44-
print("Best alpha parameters:")
45-
for i, result in enumerate(sorted_results):
46-
print(f"Rank {i + 1}:")
47-
print(f"Alpha: {result['alpha']}")
48-
print(f"Fitness: {get_fitness(result)}")
49-
print(f"Parameters: {result['simulate_data']}")
50-
print()
88+
if __name__ == "__main__":
89+
main()

brain/genetic_algorithm.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from concurrent.futures import ThreadPoolExecutor, as_completed
2+
from typing import Callable
3+
4+
from requests import Response
5+
6+
from brain.alpha_class import Alpha
7+
from brain.alpha_storage import Storage
8+
from brain.api import DEFAULT_CONFIG as API_DEFAULT_CONFIG
9+
from brain.api import BrainAPI
10+
from brain.database import Database
11+
12+
MAX_WORKERS = 3
13+
14+
15+
def genetic_algorithm(storage: Storage, create_alpha: Callable[[Storage], tuple[Response, Alpha]]):
16+
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
17+
live_jobs = {}
18+
19+
# Make initial alpha simulations
20+
for _ in range(MAX_WORKERS):
21+
response, alpha = create_alpha(storage)
22+
storage.add_alpha(alpha, "pending")
23+
live_jobs[pool.submit(_monitor_alpha, response, alpha)] = alpha
24+
25+
# Wait for jobs to complete and start new ones
26+
while live_jobs:
27+
for job in as_completed(live_jobs):
28+
# Update storage with the results
29+
alpha = live_jobs.pop(job)
30+
stats = job.result()
31+
print(f"Stats: {stats}")
32+
alpha = _update_alphas_storage(storage, stats, alpha.alpha_id)
33+
34+
# Start a new alpha simulation
35+
if alpha is not None and alpha.alpha_id is not None and alpha.fitness < -0.5:
36+
split = alpha.regular.split(";")
37+
regular = f'{";".join(split[:-1])}{";" if len(split) > 1 else ""}-({split[-1]})'
38+
39+
new_alpha = alpha.replace(regular=regular)
40+
response = BrainAPI.start_simulation(
41+
new_alpha.get_simulation_data(test_period="P1Y0M0D")
42+
)
43+
else:
44+
response, new_alpha = create_alpha(storage)
45+
46+
storage.add_alpha(new_alpha, "pending")
47+
live_jobs[pool.submit(_monitor_alpha, response, new_alpha)] = new_alpha
48+
49+
50+
def _monitor_alpha(response, alpha):
51+
"""Monitor the alpha simulation."""
52+
try:
53+
simulation_result = BrainAPI.simulation_progress(response)
54+
if not simulation_result["completed"]:
55+
return {
56+
"alpha_id": None,
57+
"simulate_data": alpha.get_simulation_data(),
58+
"error": simulation_result["error"],
59+
}
60+
61+
BrainAPI.set_alpha_properties(simulation_result["result"]["id"])
62+
return BrainAPI.get_specified_alpha_stats(
63+
simulation_result["result"]["id"], alpha.get_simulation_data(), **API_DEFAULT_CONFIG
64+
)
65+
except Exception as e:
66+
print(f"Error during obtaining results: {e}")
67+
if isinstance(e, (ConnectionError)):
68+
BrainAPI._new_session()
69+
70+
return {
71+
"alpha_id": None,
72+
"simulate_data": alpha.get_simulation_data(),
73+
"error": str(e),
74+
}
75+
76+
77+
def _update_alphas_storage(
78+
storage: Storage,
79+
stats: dict,
80+
old_id: str,
81+
):
82+
"""Update the alphas dictionary with the new stats."""
83+
storage.remove_pending_alpha(old_id)
84+
85+
if stats["alpha_id"] is None:
86+
return
87+
88+
alpha = Alpha.from_stats(stats)
89+
try:
90+
Database().insert_alpha(alpha)
91+
except Exception as e:
92+
print(f"Error during database insertion: {e}")
93+
pass
94+
95+
if alpha.short_count + alpha.long_count > 0:
96+
if (stats["is_tests"]["result"] != "FAIL").all():
97+
storage.add_alpha(alpha, "passing")
98+
else:
99+
storage.add_alpha(alpha, "failing")
100+
101+
return alpha

0 commit comments

Comments
 (0)