|
1 | | -# -*- coding: utf-8 -*- |
2 | | -# SPDX-License-Identifier: AGPL-3.0-or-later |
3 | | -# Copyright (c) 2025 Kay |
4 | | -# |
5 | | -# This file is part of SimTradeLab, dual-licensed under AGPL-3.0 and a |
6 | | -# commercial license. See LICENSE-COMMERCIAL.md or contact kayou@duck.com |
7 | | -# |
8 | | -""" |
9 | | -回测配置类 |
10 | | -""" |
11 | | - |
12 | | - |
13 | | -from __future__ import annotations |
14 | | - |
15 | | -from datetime import datetime |
16 | | -from pathlib import Path |
17 | | -from typing import Optional |
18 | | -import pandas as pd |
19 | | -from pydantic import BaseModel, Field, field_validator, model_validator |
20 | | - |
21 | | - |
22 | | -def _default_data_path(): |
23 | | - """获取默认数据路径""" |
24 | | - from ..utils.paths import DATA_PATH |
25 | | - return str(DATA_PATH) |
26 | | - |
27 | | - |
28 | | -def _default_strategies_path(): |
29 | | - """获取默认策略路径""" |
30 | | - from ..utils.paths import STRATEGIES_PATH |
31 | | - return str(STRATEGIES_PATH) |
32 | | - |
33 | | - |
34 | | -class BacktestConfig(BaseModel): |
35 | | - """回测配置参数""" |
36 | | - |
37 | | - strategy_name: str |
38 | | - start_date: str | pd.Timestamp |
39 | | - end_date: str | pd.Timestamp |
40 | | - data_path: str = Field(default_factory=_default_data_path) |
41 | | - strategies_path: str = Field(default_factory=_default_strategies_path) |
42 | | - initial_capital: float = Field(default=100000.0, gt=0, description="初始资金必须大于0") |
43 | | - use_data_server: bool = True |
44 | | - |
45 | | - # 回测频率配置 |
46 | | - frequency: str = Field(default='1d', description="回测频率: '1d'日线, '1m'分钟线") |
47 | | - |
48 | | - # 基准配置 |
49 | | - benchmark_code: str = Field(default='000300.SS', description="基准代码") |
50 | | - |
51 | | - # 性能优化配置 |
52 | | - enable_multiprocessing: bool = True |
53 | | - num_workers: Optional[int] = Field(default=None, ge=1, description="多进程worker数量") |
54 | | - enable_charts: bool = True |
55 | | - enable_logging: bool = True |
56 | | - enable_export: bool = False |
57 | | - |
58 | | - # 沙箱模式:True=限制import和builtins(Ptrade兼容),False=本地开发无限制 |
59 | | - sandbox: bool = True |
60 | | - |
61 | | - # T+1交易限制:True=A股模式(当日买入不可卖),False=T+0模式(ETF/美股) |
62 | | - t_plus_1: bool = True |
63 | | - |
64 | | - # 优化模式:跳过策略验证/数据分析/日志配置(由优化器管理) |
65 | | - optimization_mode: bool = False |
66 | | - |
67 | | - model_config = {"arbitrary_types_allowed": True} |
68 | | - |
69 | | - @field_validator('start_date', 'end_date', mode='before') |
70 | | - @classmethod |
71 | | - def convert_to_timestamp(cls, v) -> pd.Timestamp: |
72 | | - """转换日期为pd.Timestamp""" |
73 | | - if isinstance(v, pd.Timestamp): |
74 | | - return v |
75 | | - return pd.Timestamp(v) |
76 | | - |
77 | | - @model_validator(mode='after') |
78 | | - def validate_date_range(self): |
79 | | - """验证日期范围 |
80 | | -
|
81 | | - 此时start_date和end_date已被field_validator转换为pd.Timestamp |
82 | | - """ |
83 | | - if self.start_date >= self.end_date: # type: ignore |
84 | | - raise ValueError("start_date必须早于end_date") |
85 | | - return self |
86 | | - |
87 | | - @property |
88 | | - def strategy_path(self) -> str: |
89 | | - """策略文件完整路径""" |
90 | | - return str(Path(self.strategies_path) / self.strategy_name / 'backtest.py') |
91 | | - |
92 | | - @property |
93 | | - def log_dir(self) -> str: |
94 | | - """日志目录""" |
95 | | - return str(Path(self.strategies_path) / self.strategy_name / 'stats') |
96 | | - |
97 | | - def get_log_filename(self) -> str: |
98 | | - """生成日志文件名""" |
99 | | - name = 'backtest_{}_{}_{}.log'.format( |
100 | | - self.start_date.strftime("%y%m%d"), # type: ignore |
101 | | - self.end_date.strftime("%y%m%d"), # type: ignore |
102 | | - datetime.now().strftime("%y%m%d_%H%M%S")) |
103 | | - return str(Path(self.log_dir) / name) |
104 | | - |
105 | | - def get_chart_filename(self) -> str: |
106 | | - """生成图表文件名""" |
107 | | - name = 'backtest_{}_{}_{}.png'.format( |
108 | | - self.start_date.strftime("%y%m%d"), # type: ignore |
109 | | - self.end_date.strftime("%y%m%d"), # type: ignore |
110 | | - datetime.now().strftime("%y%m%d_%H%M%S")) |
111 | | - return str(Path(self.log_dir) / name) |
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +# SPDX-License-Identifier: AGPL-3.0-or-later |
| 3 | +# Copyright (c) 2025 Kay |
| 4 | +# |
| 5 | +# This file is part of SimTradeLab, dual-licensed under AGPL-3.0 and a |
| 6 | +# commercial license. See LICENSE-COMMERCIAL.md or contact kayou@duck.com |
| 7 | +# |
| 8 | +""" |
| 9 | +回测配置类 |
| 10 | +""" |
| 11 | + |
| 12 | + |
| 13 | +from __future__ import annotations |
| 14 | + |
| 15 | +from datetime import datetime |
| 16 | +from pathlib import Path |
| 17 | +from typing import Optional |
| 18 | +import pandas as pd |
| 19 | +from pydantic import BaseModel, Field, field_validator, model_validator |
| 20 | + |
| 21 | + |
| 22 | +def _default_data_path(): |
| 23 | + """获取默认数据路径""" |
| 24 | + from ..utils.paths import DATA_PATH |
| 25 | + return str(DATA_PATH) |
| 26 | + |
| 27 | + |
| 28 | +def _default_strategies_path(): |
| 29 | + """获取默认策略路径""" |
| 30 | + from ..utils.paths import STRATEGIES_PATH |
| 31 | + return str(STRATEGIES_PATH) |
| 32 | + |
| 33 | + |
| 34 | +class BacktestConfig(BaseModel): |
| 35 | + """回测配置参数""" |
| 36 | + |
| 37 | + strategy_name: str |
| 38 | + start_date: str | pd.Timestamp |
| 39 | + end_date: str | pd.Timestamp |
| 40 | + data_path: str = Field(default_factory=_default_data_path) |
| 41 | + strategies_path: str = Field(default_factory=_default_strategies_path) |
| 42 | + initial_capital: float = Field(default=100000.0, gt=0, description="初始资金必须大于0") |
| 43 | + use_data_server: bool = True |
| 44 | + |
| 45 | + # 回测频率配置 |
| 46 | + frequency: str = Field(default='1d', description="回测频率: '1d'日线, '1m'分钟线") |
| 47 | + |
| 48 | + # 基准配置 |
| 49 | + benchmark_code: str = Field(default='', description="基准代码,空串时使用市场默认基准") |
| 50 | + |
| 51 | + # 性能优化配置 |
| 52 | + enable_multiprocessing: bool = True |
| 53 | + num_workers: Optional[int] = Field(default=None, ge=1, description="多进程worker数量") |
| 54 | + enable_charts: bool = True |
| 55 | + enable_logging: bool = True |
| 56 | + enable_export: bool = False |
| 57 | + |
| 58 | + # 沙箱模式:True=限制import和builtins(Ptrade兼容),False=本地开发无限制 |
| 59 | + sandbox: bool = True |
| 60 | + |
| 61 | + # 市场选择: CN=A股, US=美股 |
| 62 | + market: str = Field(default="CN", description="市场代码") |
| 63 | + |
| 64 | + # T+1 覆盖:None=使用市场默认(CN=True, US=False),显式值覆盖市场默认 |
| 65 | + t_plus_1: Optional[bool] = None |
| 66 | + |
| 67 | + # 优化模式:跳过策略验证/数据分析/日志配置(由优化器管理) |
| 68 | + optimization_mode: bool = False |
| 69 | + |
| 70 | + # 语言:zh=中文, en=英文 |
| 71 | + locale: str = Field(default="zh", description="语言") |
| 72 | + |
| 73 | + model_config = {"arbitrary_types_allowed": True} |
| 74 | + |
| 75 | + @field_validator('start_date', 'end_date', mode='before') |
| 76 | + @classmethod |
| 77 | + def convert_to_timestamp(cls, v) -> pd.Timestamp: |
| 78 | + """转换日期为pd.Timestamp""" |
| 79 | + if isinstance(v, pd.Timestamp): |
| 80 | + return v |
| 81 | + return pd.Timestamp(v) |
| 82 | + |
| 83 | + @model_validator(mode='after') |
| 84 | + def validate_date_range(self): |
| 85 | + """验证日期范围 |
| 86 | +
|
| 87 | + 此时start_date和end_date已被field_validator转换为pd.Timestamp |
| 88 | + """ |
| 89 | + if self.start_date >= self.end_date: # type: ignore |
| 90 | + raise ValueError("start_date必须早于end_date") |
| 91 | + return self |
| 92 | + |
| 93 | + @property |
| 94 | + def strategy_path(self) -> str: |
| 95 | + """策略文件完整路径""" |
| 96 | + return str(Path(self.strategies_path) / self.strategy_name / 'backtest.py') |
| 97 | + |
| 98 | + @property |
| 99 | + def log_dir(self) -> str: |
| 100 | + """日志目录""" |
| 101 | + return str(Path(self.strategies_path) / self.strategy_name / 'stats') |
| 102 | + |
| 103 | + def get_log_filename(self) -> str: |
| 104 | + """生成日志文件名""" |
| 105 | + name = 'backtest_{}_{}_{}.log'.format( |
| 106 | + self.start_date.strftime("%y%m%d"), # type: ignore |
| 107 | + self.end_date.strftime("%y%m%d"), # type: ignore |
| 108 | + datetime.now().strftime("%y%m%d_%H%M%S")) |
| 109 | + return str(Path(self.log_dir) / name) |
| 110 | + |
| 111 | + def get_chart_filename(self) -> str: |
| 112 | + """生成图表文件名""" |
| 113 | + name = 'backtest_{}_{}_{}.png'.format( |
| 114 | + self.start_date.strftime("%y%m%d"), # type: ignore |
| 115 | + self.end_date.strftime("%y%m%d"), # type: ignore |
| 116 | + datetime.now().strftime("%y%m%d_%H%M%S")) |
| 117 | + return str(Path(self.log_dir) / name) |
0 commit comments