Skip to content

Commit 8f97292

Browse files
committed
refactor(ptrade): 优化复权计算逻辑并改进除权事件处理
- 移除前复权计算中的多余np.round调用,保持数值精度 - 在后复权处理中移除不必要的四舍五入操作 - 优化除权事件检查逻辑,直接使用整数日期索引进行匹配 - 避免重复的DataFrame列查找和条件判断,提升性能
1 parent 64c844b commit 8f97292

2 files changed

Lines changed: 13 additions & 14 deletions

File tree

src/simtradelab/ptrade/api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@ def _apply_adj_factors(self, stock_df: pd.DataFrame, stock: str, fq: str) -> pd.
178178
continue
179179
if fq == 'pre':
180180
# 前复权: (未复权价 - adj_b) / adj_a
181-
adjusted_df.loc[common_idx, col] = np.round(
182-
(adjusted_df.loc[common_idx, col] - adj_b) / adj_a, 2
181+
adjusted_df.loc[common_idx, col] = (
182+
(adjusted_df.loc[common_idx, col] - adj_b) / adj_a
183183
)
184184
else:
185185
# 后复权: adj_a * 未复权价 + adj_b
@@ -750,7 +750,7 @@ def get_history(self, count: int, frequency: str = '1d', field: str | list[str]
750750
if needs_adj_dypre:
751751
stock_result[field_name] = np.round(pre_price * adj_a_base + adj_b_base, 2)
752752
else:
753-
stock_result[field_name] = np.round(pre_price, 2)
753+
stock_result[field_name] = pre_price
754754
else:
755755
stock_result[field_name] = aligned_df[field_name].values
756756
# 后复权处理: 后复权价 = adj_a * 未复权价 + adj_b

src/simtradelab/ptrade/strategy_engine.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -472,17 +472,16 @@ def _process_dividend_events(self, current_date):
472472
# 检查除权事件(送股/配股)
473473
exrights_df = self.api.data_context.exrights_dict.get(stock_code)
474474
if exrights_df is not None and not exrights_df.empty:
475-
if 'date' in exrights_df.columns:
476-
match = exrights_df[exrights_df['date'] == current_date]
477-
if not match.empty:
478-
event = match.iloc[0]
479-
allotted = float(event.get('allotted_ps', 0) or 0)
480-
if allotted > 0:
481-
new_amount = int(original_amount * (1 + allotted))
482-
position.amount = new_amount
483-
position.enable_amount = new_amount
484-
position.cost_basis /= (1 + allotted)
485-
self.context.portfolio._invalidate_cache()
475+
date_int = int(date_str)
476+
if date_int in exrights_df.index:
477+
event = exrights_df.loc[date_int]
478+
allotted = float(event.get('allotted_ps', 0) or 0)
479+
if allotted > 0:
480+
new_amount = int(original_amount * (1 + allotted))
481+
position.amount = new_amount
482+
position.enable_amount = new_amount
483+
position.cost_basis /= (1 + allotted)
484+
self.context.portfolio._invalidate_cache()
486485

487486
# 现金分红(按登记日股数计算)
488487
if stock_code not in self.api.data_context.dividend_cache:

0 commit comments

Comments
 (0)