Skip to content

Commit fb7b711

Browse files
committed
Adding new alpha storage class + warm start from DB
1 parent 84403f0 commit fb7b711

7 files changed

Lines changed: 183 additions & 52 deletions

File tree

README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,23 @@ pre-commit install
5353
```bash
5454
./db_migrate.sh up
5555
```
56+
57+
## Running Using Docker
58+
59+
Build
60+
61+
```bash
62+
docker build -t brain:latest .
63+
```
64+
65+
Run
66+
67+
```bash
68+
docker run -d --name brain-container --restart=always brain:latest
69+
```
70+
71+
Monitor
72+
73+
```bash
74+
docker logs -f brain-container
75+
```

brain/alpha_class.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime
22
import sqlite3
3+
import uuid
34
from dataclasses import asdict, dataclass, field
45
from typing import Optional
56

@@ -9,7 +10,11 @@ class Alpha:
910
# database-generated primary key (None before insert)
1011
alpha_id: str = None
1112

13+
# Lifecycle properties
14+
is_temporary: bool = False
1215
print_counter: int = 0
16+
visible: bool = True
17+
hide_after: Optional[int] = None
1318

1419
# user-supplied columns
1520
regular: str = ""
@@ -40,11 +45,17 @@ class Alpha:
4045
default_factory=lambda: datetime.datetime.now(datetime.timezone.utc).isoformat(" ")
4146
)
4247

48+
def _increase_print_counter(self):
49+
"""Increase the print counter."""
50+
self.print_counter += 1
51+
if self.hide_after is not None and self.print_counter > self.hide_after:
52+
self.visible = False
53+
4354
def prompt_format(self) -> str:
4455
"""Format alpha data for the prompt."""
45-
self.print_counter += 1
56+
self._increase_print_counter()
4657

47-
if self.alpha_id is None:
58+
if self.alpha_id is None or self.is_temporary:
4859
return f"**Expression:** `{self.regular}`"
4960

5061
number_of_trades = (
@@ -69,14 +80,14 @@ def prompt_format(self) -> str:
6980
def as_dict(self) -> dict:
7081
"""Convert the Alpha instance to a dictionary compatible with DB."""
7182
data = asdict(self)
72-
for exclude in ["print_counter"]:
83+
for exclude in ["print_counter", "visible", "hide_after"]:
7384
data.pop(exclude, None)
7485
data["failing_tests"] = ",".join(data["failing_tests"])
7586
return data
7687

7788
@classmethod
7889
def from_config(cls, config: dict) -> "Alpha":
79-
"""Build an Alpha from a dictionary."""
90+
"""Build a temporary Alpha from a dictionary."""
8091
config_cols = [
8192
"regular",
8293
"region",
@@ -89,7 +100,9 @@ def from_config(cls, config: dict) -> "Alpha":
89100
"nan_handling",
90101
"unit_handling",
91102
]
92-
return cls(**{k: config.get(k) for k in config_cols})
103+
return cls(
104+
**{k: config.get(k) for k in config_cols}, alpha_id=str(uuid.uuid4()), is_temporary=True
105+
)
93106

94107
@classmethod
95108
def from_stats(cls, stats: dict) -> "Alpha":

brain/alpha_storage.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import heapq
2+
from typing import Callable
3+
4+
from brain.alpha_class import Alpha
5+
6+
7+
class Storage:
8+
def __init__(self, score_func: Callable[[Alpha], float], max_size: int = 50):
9+
"""Initialize the storage with empty data and categories.
10+
11+
Args:
12+
max_size: Maximum number of alphas to store in each category.
13+
"""
14+
self.score_func = score_func
15+
self.max_size = max_size
16+
self.data = {}
17+
self.categories = {
18+
"passing": [],
19+
"failing": [],
20+
"pending": [],
21+
}
22+
23+
def __getitem__(self, alpha_id: str) -> Alpha:
24+
return self.data.get(alpha_id)
25+
26+
def get_top_k(self, category: str, k: int = 10) -> list[Alpha]:
27+
"""Get the top k alphas from a specific category."""
28+
if category not in self.categories:
29+
raise ValueError(f"Invalid category: {category}")
30+
31+
return [self.data[alpha_id] for alpha_id in self.categories[category][:k]]
32+
33+
def add_alpha(self, alpha: Alpha, category: str) -> None:
34+
"""Add an alpha to the storage in the specified category."""
35+
if category not in self.categories:
36+
raise ValueError(f"Invalid category: {category}")
37+
38+
self.data[alpha.alpha_id] = alpha
39+
if alpha.alpha_id not in self.categories[category]:
40+
self._append_to_category(alpha.alpha_id, category)
41+
42+
def remove_pending_alpha(self, alpha_id: str) -> None:
43+
"""Remove an alpha from the pending category."""
44+
if alpha_id in self.categories["pending"]:
45+
self.categories["pending"].remove(alpha_id)
46+
self.data.pop(alpha_id, None)
47+
48+
def _score(self, alpha_id: str) -> float:
49+
return self.score_func(self.data[alpha_id])
50+
51+
def _append_to_category(self, alpha_id: str, category: str) -> None:
52+
"""Append an alpha to a specific category."""
53+
self.categories[category].append(alpha_id)
54+
if category != "pending":
55+
self.categories[category] = heapq.nlargest(
56+
self.max_size, self.categories[category], key=self._score
57+
)

brain/database.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,29 @@ def find_by_code(self, code: str, neutralization: str, delay: int) -> list[Alpha
4040
rows = self.cursor.fetchall()
4141
return [Alpha.from_row(r) for r in rows]
4242

43+
def k_best_alphas(
44+
self,
45+
metric: str = "sharpe",
46+
top_k: int = 100,
47+
min_fitness: float = 1.0,
48+
max_self_corr: float = 0.6,
49+
) -> list[Alpha]:
50+
"""Find best performing alphas by certain metric."""
51+
# Make sure `metric` is a valid column in your `alphas` table!
52+
sql = f"""
53+
SELECT *
54+
FROM alphas
55+
WHERE {metric} IS NOT NULL
56+
AND fitness > %s
57+
AND self_correlation < %s
58+
ORDER BY {metric} DESC
59+
LIMIT %s
60+
"""
61+
params = (min_fitness, max_self_corr, top_k)
62+
self.cursor.execute(sql, params)
63+
rows = self.cursor.fetchall()
64+
return [Alpha.from_row(r) for r in rows]
65+
4366
def close(self):
4467
self.cursor.close()
4568
self.conn.close()

brain/search_algorithm.py

Lines changed: 63 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
import heapq
21
import random
3-
import uuid
42
from concurrent.futures import ThreadPoolExecutor, as_completed
53

64
from brain.agent import agent
75
from brain.agent_config import DEFAULT_CONFIG
86
from brain.alpha_class import Alpha
7+
from brain.alpha_storage import Storage
98
from brain.api import DEFAULT_CONFIG as API_DEFAULT_CONFIG
109
from brain.api import BrainAPI
1110
from brain.database import Database
@@ -20,23 +19,23 @@ def decay_hyperbolic(x, gamma=0.2, delta=0.1):
2019
return (gamma * x) / (1 + delta * x)
2120

2221

23-
def create_alpha_simulation(alphas_dict, alphas_categories):
24-
"""Create a new alpha based on the given ID."""
22+
def get_score(alpha: Alpha):
23+
if not alpha.visible:
24+
return float("-inf")
2525

26-
def get_score(alpha_id):
27-
return (
28-
alphas_dict[alpha_id].fitness
29-
+ 1.5 * alphas_dict[alpha_id].sharpe
30-
- decay_hyperbolic(alphas_dict[alpha_id].print_counter, gamma=0.01, delta=0.02)
31-
)
26+
return (
27+
alpha.fitness
28+
+ 1.5 * alpha.sharpe
29+
- decay_hyperbolic(alpha.print_counter, gamma=0.01, delta=0.02)
30+
)
3231

33-
n_largest = 10
34-
for cat in ["passing", "failing"]:
35-
alphas_categories[cat] = heapq.nlargest(n_largest, alphas_categories[cat], key=get_score)
32+
33+
def create_alpha_simulation(storage: Storage):
34+
"""Create a new alpha based on the given ID."""
3635

3736
formatted_alphas = {
38-
cat: "\n".join(alphas_dict[id].prompt_format() for id in alphas_categories[cat])
39-
for cat in alphas_categories.keys()
37+
cat: "\n".join(alpha.prompt_format() for alpha in storage.get_top_k(cat, 10))
38+
for cat in storage.categories
4039
}
4140

4241
if random.random() < 0.05:
@@ -114,60 +113,80 @@ def monitor_alpha(response, alpha_config):
114113
}
115114

116115

117-
def update_alphas_dict(alphas_dict, alphas_categories, stats, temp_id):
116+
def update_alphas_dict(
117+
storage: Storage,
118+
stats: dict,
119+
temp_id: str,
120+
):
118121
"""Update the alphas dictionary with the new stats."""
119-
alphas_categories["pending"].remove(temp_id)
120-
alphas_dict.pop(temp_id)
122+
storage.remove_pending_alpha(temp_id)
121123

122124
if stats["alpha_id"] is None:
123125
return
124126

125-
alpha_id = stats["alpha_id"]
126-
alphas_dict[alpha_id] = Alpha.from_stats(stats)
127+
alpha = Alpha.from_stats(stats)
127128
try:
128-
Database().insert_alpha(alphas_dict[alpha_id])
129+
Database().insert_alpha(alpha)
129130
except Exception as e:
130131
print(f"Error during database insertion: {e}")
131132
pass
132133

133-
if alphas_dict[alpha_id].short_count + alphas_dict[alpha_id].long_count > 0:
134+
if alpha.short_count + alpha.long_count > 0:
134135
if (stats["is_tests"]["result"] != "FAIL").all():
135-
alphas_categories["passing"].append(alpha_id)
136+
storage.add_alpha(alpha, "passing")
136137
else:
137-
alphas_categories["failing"].append(alpha_id)
138+
storage.add_alpha(alpha, "failing")
138139

139-
return alphas_dict[alpha_id]
140+
return alpha
141+
142+
143+
def set_warm_start_alphas(storage: Storage) -> None:
144+
"""Initialize alphas_dict with warm start alphas from the database."""
145+
try:
146+
alphas = Database().k_best_alphas(
147+
metric="sharpe",
148+
top_k=100,
149+
min_fitness=1.0,
150+
max_self_corr=0.6,
151+
)
152+
153+
alphas = random.sample(alphas, min(10, len(alphas)))
154+
for alpha in alphas:
155+
alpha.hide_after = 30
156+
storage.add_alpha(alpha, "failing")
157+
158+
except Exception as e:
159+
print(f"Error during database query: {e}")
140160

141161

142162
def main():
143163
"""Main function to run the agent."""
144-
alphas_dict = {}
145-
alphas_categories = {
146-
"passing": [],
147-
"failing": [],
148-
"pending": [],
149-
}
164+
storage = Storage(score_func=get_score, max_size=50)
165+
166+
set_warm_start_alphas(storage)
150167

151168
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
152169
live_jobs = {}
153170

154171
for _ in range(MAX_WORKERS):
155172
# Start a new alpha simulation
156-
response, alpha_config = create_alpha_simulation(alphas_dict, alphas_categories)
157-
158-
# Generate a unique ID for the alpha
159-
temp_id = str(uuid.uuid4())
160-
alphas_categories["pending"].append(temp_id)
161-
alphas_dict[temp_id] = Alpha.from_config(alpha_config)
162-
live_jobs[pool.submit(monitor_alpha, response, alpha_config)] = (temp_id, alpha_config)
173+
response, alpha_config = create_alpha_simulation(storage)
174+
175+
# Create a temporary alpha configuration
176+
alpha = Alpha.from_config(alpha_config)
177+
storage.add_alpha(alpha, "pending")
178+
live_jobs[pool.submit(monitor_alpha, response, alpha_config)] = (
179+
alpha.alpha_id,
180+
alpha_config,
181+
)
163182

164183
while live_jobs:
165184
for job in as_completed(live_jobs):
166185
# Update alphas_dict with the results
167186
temp_id, alpha_config = live_jobs.pop(job) # remove from “running” set
168187
stats = job.result()
169188
print(f"Stats: {stats}")
170-
alpha = update_alphas_dict(alphas_dict, alphas_categories, stats, temp_id)
189+
alpha = update_alphas_dict(storage, stats, temp_id)
171190

172191
# Start a new alpha simulation
173192
if alpha is not None and alpha.alpha_id is not None and alpha.fitness < -0.5:
@@ -176,12 +195,11 @@ def main():
176195
alpha_config = {**alpha_config, "regular": regular}
177196
response = BrainAPI.start_simulation(alpha_config)
178197
else:
179-
response, alpha_config = create_alpha_simulation(alphas_dict, alphas_categories)
180-
# Generate a unique ID for the alpha
181-
temp_id = str(uuid.uuid4())
182-
alphas_categories["pending"].append(temp_id)
183-
alphas_dict[temp_id] = Alpha.from_config(alpha_config)
198+
response, alpha_config = create_alpha_simulation(storage)
199+
# TODO: Turn this into a method + stop using alpha_config
200+
alpha = Alpha.from_config(alpha_config)
201+
storage.add_alpha(alpha, "pending")
184202
live_jobs[pool.submit(monitor_alpha, response, alpha_config)] = (
185-
temp_id,
203+
alpha.alpha_id,
186204
alpha_config,
187205
)

brain/tools/simulation.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33
import time
44
from typing import Annotated
55

6-
from langchain_core.messages import ToolMessage
76
from langchain_core.runnables import RunnableConfig
87
from langchain_core.tools import InjectedToolCallId, tool
9-
from langgraph.graph import END
108
from langgraph.types import Command
119

1210
from brain.agent_config import get_universe_config

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def parse_requirements(file_name):
5050
platforms=["Windows", "Linux", "Solaris", "Mac OS-X", "Unix"],
5151
python_requires=">=3.9",
5252
install_requires=REQUIREMENTS,
53+
include_package_data=True,
54+
package_data={"brain": ["tools/data/*"]},
5355
zip_safe=False,
5456
entry_points={
5557
"console_scripts": [

0 commit comments

Comments
 (0)