[Add]run ga optimization in CtaBacktester
This commit is contained in:
parent
4cd84b45a5
commit
a11e6be3ce
1
.flake8
1
.flake8
@ -5,3 +5,4 @@ ignore =
|
||||
W503 line break before binary operator
|
||||
W293 blank line contains whitespace
|
||||
W291 trailing whitespace
|
||||
W391 blank line at end of file
|
||||
|
@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -16,7 +16,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -54,11 +54,62 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2019-05-03 16:19:04.193703\t开始运行遗传算法,每代族群总数:11, 优良品种筛选个数:8,迭代次数:30,交叉概率:0.95,突变概率:0.050000000000000044\n",
|
||||
"gen\tnevals\tmean \tstd \tmin \tmax \n",
|
||||
"0 \t11 \t[0.58423524]\t[0.30377007]\t[0.13231977]\t[1.2382818]\n",
|
||||
"1 \t11 \t[0.90248989]\t[0.15747112]\t[0.68707859]\t[1.2382818]\n",
|
||||
"2 \t11 \t[1.09406088]\t[0.18860523]\t[0.86284921]\t[1.46762684]\n",
|
||||
"3 \t11 \t[1.21413386]\t[0.12138014]\t[1.02072108]\t[1.46762684]\n",
|
||||
"4 \t11 \t[1.29561806]\t[0.09930932]\t[1.2382818] \t[1.46762684]\n",
|
||||
"5 \t11 \t[1.41029058]\t[0.09930932]\t[1.2382818] \t[1.46762684]\n",
|
||||
"6 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"7 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"8 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"9 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"10 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"11 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"12 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"13 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"14 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"15 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"16 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"17 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"18 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"19 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"20 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"21 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"22 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"23 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"24 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"25 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"26 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"27 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"28 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"29 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"30 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
|
||||
"2019-05-03 16:19:58.256354\t遗传算法优化完成,耗时54秒\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[({'atr_length': 38, 'atr_ma_length': 25}, 1.4676268402266743)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"setting = OptimizationSetting()\n",
|
||||
"setting.set_target(\"sharpe_ratio\")\n",
|
||||
|
@ -222,20 +222,25 @@ class BacktesterEngine(BaseEngine):
|
||||
return strategy_class.get_class_parameters()
|
||||
|
||||
def run_optimization(
|
||||
self,
|
||||
class_name: str,
|
||||
vt_symbol: str,
|
||||
interval: str,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
rate: float,
|
||||
slippage: float,
|
||||
size: int,
|
||||
pricetick: float,
|
||||
capital: int,
|
||||
optimization_setting: OptimizationSetting):
|
||||
self,
|
||||
class_name: str,
|
||||
vt_symbol: str,
|
||||
interval: str,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
rate: float,
|
||||
slippage: float,
|
||||
size: int,
|
||||
pricetick: float,
|
||||
capital: int,
|
||||
optimization_setting: OptimizationSetting,
|
||||
use_ga: bool
|
||||
):
|
||||
""""""
|
||||
self.write_log("开始多进程参数优化")
|
||||
if use_ga:
|
||||
self.write_log("开始遗传算法参数优化")
|
||||
else:
|
||||
self.write_log("开始多进程参数优化")
|
||||
|
||||
self.result_values = None
|
||||
|
||||
@ -260,10 +265,16 @@ class BacktesterEngine(BaseEngine):
|
||||
{}
|
||||
)
|
||||
|
||||
self.result_values = engine.run_optimization(
|
||||
optimization_setting,
|
||||
output=False
|
||||
)
|
||||
if use_ga:
|
||||
self.result_values = engine.run_ga_optimization(
|
||||
optimization_setting,
|
||||
output=False
|
||||
)
|
||||
else:
|
||||
self.result_values = engine.run_optimization(
|
||||
optimization_setting,
|
||||
output=False
|
||||
)
|
||||
|
||||
# Clear thread object handler.
|
||||
self.thread = None
|
||||
@ -285,7 +296,8 @@ class BacktesterEngine(BaseEngine):
|
||||
size: int,
|
||||
pricetick: float,
|
||||
capital: int,
|
||||
optimization_setting: OptimizationSetting
|
||||
optimization_setting: OptimizationSetting,
|
||||
use_ga: bool
|
||||
):
|
||||
if self.thread:
|
||||
self.write_log("已有任务在运行中,请等待完成")
|
||||
@ -305,7 +317,8 @@ class BacktesterEngine(BaseEngine):
|
||||
size,
|
||||
pricetick,
|
||||
capital,
|
||||
optimization_setting
|
||||
optimization_setting,
|
||||
use_ga
|
||||
)
|
||||
)
|
||||
self.thread.start()
|
||||
|
@ -240,7 +240,7 @@ class BacktesterManager(QtWidgets.QWidget):
|
||||
if i != dialog.Accepted:
|
||||
return
|
||||
|
||||
optimization_setting = dialog.get_setting()
|
||||
optimization_setting, use_ga = dialog.get_setting()
|
||||
self.target_display = dialog.target_display
|
||||
|
||||
self.backtester_engine.start_optimization(
|
||||
@ -254,7 +254,8 @@ class BacktesterManager(QtWidgets.QWidget):
|
||||
size,
|
||||
pricetick,
|
||||
capital,
|
||||
optimization_setting
|
||||
optimization_setting,
|
||||
use_ga
|
||||
)
|
||||
|
||||
self.result_button.setEnabled(False)
|
||||
@ -592,6 +593,7 @@ class OptimizationSettingEditor(QtWidgets.QDialog):
|
||||
self.edits = {}
|
||||
|
||||
self.optimization_setting = None
|
||||
self.use_ga = False
|
||||
|
||||
self.init_ui()
|
||||
|
||||
@ -642,12 +644,27 @@ class OptimizationSettingEditor(QtWidgets.QDialog):
|
||||
|
||||
row += 1
|
||||
|
||||
button = QtWidgets.QPushButton("确定")
|
||||
button.clicked.connect(self.generate_setting)
|
||||
grid.addWidget(button, row, 0, 1, 4)
|
||||
parallel_button = QtWidgets.QPushButton("多进程优化")
|
||||
parallel_button.clicked.connect(self.generate_parallel_setting)
|
||||
grid.addWidget(parallel_button, row, 0, 1, 4)
|
||||
|
||||
row += 1
|
||||
ga_button = QtWidgets.QPushButton("遗传算法优化")
|
||||
ga_button.clicked.connect(self.generate_ga_setting)
|
||||
grid.addWidget(ga_button, row, 0, 1, 4)
|
||||
|
||||
self.setLayout(grid)
|
||||
|
||||
def generate_ga_setting(self):
|
||||
""""""
|
||||
self.use_ga = True
|
||||
self.generate_setting()
|
||||
|
||||
def generate_parallel_setting(self):
|
||||
""""""
|
||||
self.use_ga = False
|
||||
self.generate_setting()
|
||||
|
||||
def generate_setting(self):
|
||||
""""""
|
||||
self.optimization_setting = OptimizationSetting()
|
||||
@ -676,7 +693,7 @@ class OptimizationSettingEditor(QtWidgets.QDialog):
|
||||
|
||||
def get_setting(self):
|
||||
""""""
|
||||
return self.optimization_setting
|
||||
return self.optimization_setting, self.use_ga
|
||||
|
||||
|
||||
class OptimizationResultMonitor(QtWidgets.QDialog):
|
||||
|
@ -601,8 +601,12 @@ class BacktestingEngine:
|
||||
# toolbox.register("map", pool.map)
|
||||
|
||||
# Run ga optimization
|
||||
msg = "开始运行遗传算法,每代族群总数:%s, 优良品种筛选个数:%s,迭代次数:%s,交叉概率:%s,突变概率:%s" % (pop_size, mu, ngen, cxpb, mutpb)
|
||||
self.output(msg)
|
||||
self.output(f"参数优化空间:{total_size}")
|
||||
self.output(f"每代族群总数:{pop_size}")
|
||||
self.output(f"优良筛选个数:{mu}")
|
||||
self.output(f"迭代次数:{ngen}")
|
||||
self.output(f"交叉概率:{cxpb:.0%}")
|
||||
self.output(f"突变概率:{mutpb:.0%}")
|
||||
|
||||
start = time()
|
||||
|
||||
@ -630,7 +634,7 @@ class BacktestingEngine:
|
||||
for parameter_values in hof:
|
||||
setting = dict(zip(parameter_keys, parameter_values))
|
||||
target_value = ga_optimize(parameter_values)[0]
|
||||
results.append((setting, target_value))
|
||||
results.append((setting, target_value, {}))
|
||||
|
||||
return results
|
||||
|
||||
@ -1059,12 +1063,13 @@ def optimize(
|
||||
pricetick: float,
|
||||
capital: int,
|
||||
end: datetime,
|
||||
mode: BacktestingMode,
|
||||
mode: BacktestingMode
|
||||
):
|
||||
"""
|
||||
Function for running in multiprocessing.pool
|
||||
"""
|
||||
engine = BacktestingEngine()
|
||||
|
||||
engine.set_parameters(
|
||||
vt_symbol=vt_symbol,
|
||||
interval=interval,
|
||||
|
Loading…
Reference in New Issue
Block a user