|
1 | 1 | """Database module for storing and managing alphas.""" |
2 | 2 |
|
3 | | -import sqlite3 |
4 | | -from pathlib import Path |
| 3 | +import os |
5 | 4 |
|
6 | 5 | import numpy as np |
| 6 | +import psycopg |
| 7 | +from dotenv import load_dotenv |
| 8 | +from psycopg.rows import dict_row |
7 | 9 |
|
8 | 10 | from brain.alpha_class import Alpha |
9 | 11 |
|
10 | | -DB_PATH = Path(__file__).parent.parent / "alphas_database.db" |
| 12 | +load_dotenv() |
11 | 13 |
|
12 | | - |
13 | | -sqlite3.register_adapter(np.int64, lambda x: int(x)) |
| 14 | +psycopg.adapters.register_dumper(np.int64, psycopg.types.numeric.IntDumper) |
14 | 15 |
|
15 | 16 |
|
16 | 17 | class Database: |
17 | | - def __init__(self, db_path=DB_PATH): |
18 | | - self.conn = sqlite3.connect(db_path) |
19 | | - self.conn.row_factory = sqlite3.Row |
| 18 | + def __init__(self, db_url: str = None): |
| 19 | + """Initialize the database connection.""" |
| 20 | + self.db_url = db_url or os.environ.get("DATABASE_URL") |
| 21 | + self.conn = psycopg.connect(self.db_url, row_factory=dict_row) |
| 22 | + self.conn.set_autocommit(True) |
20 | 23 | self.cursor = self.conn.cursor() |
21 | | - self._create_table() |
22 | | - |
23 | | - def _create_table(self): |
24 | | - schema = """ |
25 | | - CREATE TABLE IF NOT EXISTS alphas ( |
26 | | - alpha_id TEXT PRIMARY KEY, |
27 | | - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, |
28 | | - regular TEXT NOT NULL, |
29 | | - region TEXT NOT NULL, |
30 | | - universe TEXT NOT NULL, |
31 | | - decay INTEGER NOT NULL, |
32 | | - delay INTEGER NOT NULL, |
33 | | - truncation REAL NOT NULL, |
34 | | - neutralization TEXT NOT NULL, |
35 | | - pasteurization TEXT NOT NULL, |
36 | | - nan_handling TEXT NOT NULL, |
37 | | - unit_handling TEXT NOT NULL, |
38 | | - fitness REAL, |
39 | | - sharpe REAL, |
40 | | - returns REAL, |
41 | | - drawdown REAL, |
42 | | - turnover REAL, |
43 | | - margin REAL, |
44 | | - long_count INTEGER, |
45 | | - short_count INTEGER, |
46 | | - self_correlation REAL, |
47 | | - failing_tests TEXT |
48 | | - ); |
49 | | - """ |
50 | | - self.cursor.execute(schema) |
51 | | - self.conn.commit() |
52 | 24 |
|
53 | 25 | def insert_alpha(self, alpha: Alpha) -> int: |
54 | 26 | """Insert a Alpha class instance into the table.""" |
55 | 27 | record = alpha.as_dict() |
56 | 28 | cols = ", ".join(record) |
57 | | - bangs = ", ".join("?" for _ in record) |
58 | | - |
| 29 | + bangs = ", ".join("%s" for _ in record) |
59 | 30 | sql = f"INSERT INTO alphas ({cols}) VALUES ({bangs})" |
60 | 31 | self.cursor.execute(sql, tuple(record.values())) |
61 | | - self.conn.commit() |
62 | 32 |
|
63 | 33 | def find_by_code(self, code: str, neutralization: str, delay: int) -> list[Alpha]: |
64 | 34 | """Find an alpha by its code.""" |
65 | 35 | sql = ( |
66 | | - "SELECT * FROM alphas WHERE regular = ? AND neutralization = ? AND delay = ? " |
| 36 | + "SELECT * FROM alphas WHERE regular = %s AND neutralization = %s AND delay = %s " |
67 | 37 | "ORDER BY created_at DESC" |
68 | 38 | ) |
69 | | - rows = self.cursor.execute(sql, (code, neutralization, delay)) |
| 39 | + self.cursor.execute(sql, (code, neutralization, delay)) |
| 40 | + rows = self.cursor.fetchall() |
70 | 41 | return [Alpha.from_row(r) for r in rows] |
71 | 42 |
|
72 | 43 | def close(self): |
| 44 | + self.cursor.close() |
73 | 45 | self.conn.close() |
0 commit comments