-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathindicators_storage.py
More file actions
285 lines (232 loc) · 11.7 KB
/
indicators_storage.py
File metadata and controls
285 lines (232 loc) · 11.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
import pandas as pd
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from rsi_component import calculate_rsi, detect_rsi_divergence
from supertrend_component import calculate_supertrend, get_trend_signals
from volume_indicators import calculate_volume_indicators
from stock_cache import StockDataCache
@dataclass
class TechnicalIndicators:
    """Per-day technical-indicator snapshot for one stock.

    IndicatorsStorage._calculate_all_indicators builds one instance per row of
    the enhanced DataFrame. Every field except ``date`` defaults to None.
    """
    date: str  # trading date; 'YYYY-MM-DD' when the source value supports strftime
    rsi14: Optional[float] = None  # 14-period RSI
    ma10: Optional[float] = None  # 10-period rolling mean of the close price
    daily_change_pct: Optional[float] = None  # daily change, in percent
    # Unified SuperTrend parameters
    upper_band: Optional[float] = None  # SuperTrend upper band
    lower_band: Optional[float] = None  # SuperTrend lower band
    trend: Optional[int] = None  # trend direction: 1 = up, -1 = down, 0 = neutral
    # Volume-related indicators
    volume: Optional[float] = None  # volume traded on the day
    vol_ratio: Optional[float] = None  # volume ratio: today's volume / average volume of the previous 5 days
    # Extended volume indicators
    vol_20d_avg: Optional[float] = None  # 20-day average volume
    vol_20d_max: Optional[float] = None  # 20-day maximum volume
    vol_50d_min: Optional[float] = None  # 50-day minimum volume
    is_high_vol_bar: Optional[bool] = None  # whether this is a high-volume bar
    is_sky_vol_bar: Optional[bool] = None  # whether this is an explosive ("sky") volume bar
    is_low_vol_bar: Optional[bool] = None  # whether this is an extreme low-volume (contraction) bar
    near_20d_high: Optional[bool] = None  # whether the price is close to its 20-day high
    price_condition: Optional[bool] = None  # whether the price condition is met (defined by the volume-indicator code)
@dataclass
class RSIDivergence:
    """One RSI divergence signal.

    Pairs a previous point (prev_date / prev_rsi / prev_price) with the current
    point (date / current_rsi / current_price).
    IndicatorsStorage._calculate_rsi_divergences fills these from the rows
    returned by detect_rsi_divergence, rounding every float to 2 decimals.
    """
    date: str  # date of the current point
    prev_date: str  # date of the earlier point the divergence is measured against
    type: str  # 'bullish' or 'bearish'
    timeframe: str  # 'short', 'medium', 'long'
    rsi_change: float  # RSI change between the two points (exact definition lives in detect_rsi_divergence)
    price_change: float  # price change between the two points (exact definition lives in detect_rsi_divergence)
    confidence: float  # confidence score from detect_rsi_divergence; scale not visible here
    current_rsi: float
    prev_rsi: float
    current_price: float
    prev_price: float
@dataclass
class TrendSignal:
    """A buy/sell signal built by IndicatorsStorage._calculate_trend_signals.

    The underlying positions come from get_trend_signals.
    """
    date: str  # signal date; 'YYYY-MM-DD' when the source value supports strftime
    signal_type: str  # 'buy' or 'sell'
    price: float  # close price on the signal day, rounded to 2 decimals
    trend_value: float  # SuperTrend value on that day, rounded to 2 decimals (0.0 when missing)
class IndicatorsStorage:
    """Technical-indicator storage manager backed by an SQLite database.

    Computes indicators on a price/volume DataFrame and persists the per-day
    indicators, RSI divergences and trend signals through ``StockDataCache``.
    """

    # Numeric columns of a detect_rsi_divergence row; each is rounded to 2 decimals.
    _DIVERGENCE_FLOAT_FIELDS = (
        'rsi_change',
        'price_change',
        'confidence',
        'current_rsi',
        'prev_rsi',
        'current_price',
        'prev_price',
    )

    def __init__(self, cache_dir: str = "cache"):
        # All persistence goes through a StockDataCache rooted at cache_dir.
        self.cache = StockDataCache(cache_dir)

    def calculate_and_store_indicators(self, df: pd.DataFrame, stock_name: str, symbol: str = None) -> Dict[str, Any]:
        """Compute every technical indicator for ``df`` and persist the results.

        Args:
            df: Price/volume history. It must contain the columns ``日期``
                (date), ``收盘`` (close) and ``成交量`` (volume). The frame is
                copied, so the caller's object is never modified.
            stock_name: Display name of the stock.
            symbol: Storage key; ``None`` means "use ``stock_name``".

        Returns:
            Dict with ``indicators`` (the bundle from _calculate_all_indicators),
            ``rsi_divergences``, ``trend_signals``, ``enhanced_dataframe``,
            ``stock_name``, ``symbol`` and an ISO-format ``calculation_time``.
        """
        if symbol is None:
            symbol = stock_name

        # Work on a copy so the caller's DataFrame is never modified.
        data = df.copy()

        indicators_data = self._calculate_all_indicators(data)
        enhanced_df = indicators_data['dataframe']

        rsi_divergences = self._calculate_rsi_divergences(enhanced_df, indicators_data['rsi'])
        # Trend signals need the SuperTrend columns present in the enhanced frame.
        trend_signals = self._calculate_trend_signals(enhanced_df)

        self.cache.save_technical_indicators(symbol, stock_name, indicators_data['indicators'])
        self.cache.save_rsi_divergences(symbol, stock_name, rsi_divergences)
        self.cache.save_trend_signals(symbol, stock_name, trend_signals)

        return {
            'indicators': indicators_data,
            'rsi_divergences': rsi_divergences,
            'trend_signals': trend_signals,
            'enhanced_dataframe': enhanced_df,  # enhanced frame with every indicator column
            'stock_name': stock_name,
            'symbol': symbol,
            'calculation_time': datetime.now().isoformat(),
        }

    def _calculate_all_indicators(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Add every indicator column to ``df`` and build per-row records.

        Returns:
            ``{'indicators': List[TechnicalIndicators], 'rsi': Series,
            'dataframe': DataFrame}``. Here ``dataframe`` is the enhanced frame
            and ``rsi`` is its ``rsi`` column.
        """
        # Rolling windows and shifts depend on row order, so the rows must be
        # chronological BEFORE anything is computed. Previously the frame was
        # only sorted afterwards, and only in the pct_change branch below.
        df = self._ensure_chronological(df)

        # SuperTrend (per its component: super_trend, upper_band, lower_band, trend).
        df = calculate_supertrend(df, lookback_periods=14, multiplier=2)

        # RSI14, plus an ``rsi`` alias kept for backward compatibility.
        df['rsi14'] = calculate_rsi(df, period=14)
        df['rsi'] = df['rsi14']

        df['ma10'] = df['收盘'].rolling(window=10).mean()

        # Volume ratio: today's volume / mean volume of the previous 5 sessions.
        # shift(1) keeps today's volume out of the average.
        df['vol_ratio'] = df['成交量'] / df['成交量'].shift(1).rolling(window=5).mean()

        # Volume indicators (three kinds of special volume bars).
        df = calculate_volume_indicators(df)

        # Daily change %: prefer akshare's 涨跌幅 column, otherwise derive it from the close.
        if '涨跌幅' in df.columns:
            df['日涨幅'] = df['涨跌幅']
        elif '日涨幅' not in df.columns:
            df = df.sort_values('日期').reset_index(drop=True)
            df['日涨幅'] = df['收盘'].pct_change() * 100

        indicators_list = [self._row_to_indicator(row) for _, row in df.iterrows()]

        return {
            'indicators': indicators_list,
            'rsi': df['rsi'],  # same values as rsi14; used for divergence detection
            'dataframe': df,   # keep the full enhanced DataFrame
        }

    @staticmethod
    def _ensure_chronological(df: pd.DataFrame) -> pd.DataFrame:
        """Return ``df`` ordered by ``日期``.

        A frame that is already in non-decreasing date order is returned
        as-is (same object, same index). Otherwise a sorted copy with a fresh
        RangeIndex is returned.
        """
        if df['日期'].is_monotonic_increasing:
            return df
        return df.sort_values('日期').reset_index(drop=True)

    @classmethod
    def _row_to_indicator(cls, row: pd.Series) -> "TechnicalIndicators":
        """Convert one enhanced-DataFrame row into a TechnicalIndicators record.

        Missing columns and NaN cells become None for numeric fields, False
        for the boolean flags, and 0 for ``trend``.
        """
        get = row.get  # Series.get returns the default for absent column labels
        to_float = cls._to_float
        to_bool = cls._to_bool
        return TechnicalIndicators(
            date=cls._format_date(row['日期']),
            rsi14=to_float(get('rsi14')),
            ma10=to_float(get('ma10')),
            daily_change_pct=to_float(get('日涨幅'), 4),
            upper_band=to_float(get('upper_band')),
            lower_band=to_float(get('lower_band')),
            trend=cls._to_int(get('trend', 0)),
            volume=to_float(get('成交量')),
            vol_ratio=to_float(get('vol_ratio'), 2),
            vol_20d_avg=to_float(get('vol_20d_avg')),
            vol_20d_max=to_float(get('vol_20d_max')),
            vol_50d_min=to_float(get('vol_50d_min')),
            is_high_vol_bar=to_bool(get('is_high_vol_bar')),
            is_sky_vol_bar=to_bool(get('is_sky_vol_bar')),
            is_low_vol_bar=to_bool(get('is_low_vol_bar')),
            near_20d_high=to_bool(get('near_20d_high')),
            price_condition=to_bool(get('price_condition')),
        )

    @staticmethod
    def _format_date(value: Any) -> str:
        """Render a date-like value as 'YYYY-MM-DD'; anything without strftime goes through str()."""
        return value.strftime('%Y-%m-%d') if hasattr(value, 'strftime') else str(value)

    @staticmethod
    def _to_float(value: Any, decimal_places: int = 2) -> Optional[float]:
        """Round ``value`` (float, int, Decimal, ...) to ``decimal_places``.

        Returns None for missing or unconvertible values.
        """
        if value is None or pd.isnull(value):
            return None
        try:
            return round(float(value), decimal_places)
        except (ValueError, TypeError):
            return None

    @staticmethod
    def _to_int(value: Any, default: int = 0) -> int:
        """Convert ``value`` to int, returning ``default`` for missing or unconvertible values."""
        if value is None or pd.isnull(value):
            return default
        try:
            return int(value)
        except (ValueError, TypeError):
            return default

    @staticmethod
    def _to_bool(value: Any, default: bool = False) -> bool:
        """Convert ``value`` to bool, returning ``default`` for missing or unconvertible values."""
        if value is None or pd.isnull(value):
            return default
        try:
            return bool(value)
        except (ValueError, TypeError):
            return default

    def _calculate_rsi_divergences(self, df: pd.DataFrame, rsi: pd.Series) -> List["RSIDivergence"]:
        """Run detect_rsi_divergence and convert each result row into an RSIDivergence."""
        divergences_df = detect_rsi_divergence(df, rsi)
        divergences = []
        if divergences_df.empty:
            return divergences
        for _, row in divergences_df.iterrows():
            rounded = {name: round(float(row[name]), 2) for name in self._DIVERGENCE_FLOAT_FIELDS}
            divergences.append(RSIDivergence(
                date=self._format_date(row['date']),
                prev_date=self._format_date(row['prev_date']),
                type=row['type'],
                timeframe=row['timeframe'],
                **rounded,
            ))
        return divergences

    def _calculate_trend_signals(self, df: pd.DataFrame) -> List["TrendSignal"]:
        """Convert the buy/sell row positions from get_trend_signals into TrendSignal records.

        Positions at or beyond ``len(df)`` are ignored. The result is sorted
        by date. The sort is stable and buys are collected first, so a buy
        precedes a sell on the same date.
        """
        buy_positions, sell_positions = get_trend_signals(df)
        signals = []
        for signal_type, positions in (('buy', buy_positions), ('sell', sell_positions)):
            for pos in positions:
                if pos >= len(df):
                    continue
                row = df.iloc[pos]
                super_trend = row['super_trend']
                signals.append(TrendSignal(
                    date=self._format_date(row['日期']),
                    signal_type=signal_type,
                    price=round(float(row['收盘']), 2),
                    trend_value=round(float(super_trend), 2) if pd.notnull(super_trend) else 0.0,
                ))
        signals.sort(key=lambda signal: signal.date)
        return signals

    def get_latest_indicators(self, stock_name: str) -> Optional[Dict[str, Any]]:
        """Return the latest technical-indicator summary stored for ``stock_name``."""
        return self.cache.get_latest_indicators(stock_name)

    def export_to_dataframe(self, stock_name: str) -> Optional[pd.DataFrame]:
        """Return the stored indicator data for ``stock_name`` as a DataFrame."""
        return self.cache.get_indicators_dataframe(stock_name)
def enhance_analysis_with_indicators(df: pd.DataFrame, stock_name: str, symbol: str = None) -> Dict[str, Any]:
    """Convenience entry point for analysis.py: compute, persist and return indicators.

    Args:
        df: Price/volume history, passed to
            ``IndicatorsStorage.calculate_and_store_indicators``.
        stock_name: Display name of the stock.
        symbol: Storage key; ``None`` means "use ``stock_name``".

    Returns:
        Dict with three keys:
        ``enhanced_dataframe`` is a copy of the enhanced frame, so callers may
        mutate it without affecting ``storage_result``.
        ``indicators_summary`` is the latest summary read back from storage.
        ``storage_result`` is the full result of the calculation.
    """
    storage = IndicatorsStorage()  # default cache directory ("cache")
    result = storage.calculate_and_store_indicators(df, stock_name, symbol)

    # Read the summary back under the exact symbol the data was stored with.
    # The old ``symbol or stock_name`` disagreed with the storage key for
    # falsy-but-not-None symbols such as ''.
    return {
        'enhanced_dataframe': result['enhanced_dataframe'].copy(),
        'indicators_summary': storage.get_latest_indicators(result['symbol']),
        'storage_result': result,
    }